Benchmarks: Add Benchmark - Add GPT2 model benchmark. (#57)
* Benchmarks: Add Benchmark - Add GPT2 model benchmark.
This commit is contained in:
Родитель
fb850af760
Коммит
af567cf650
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Model benchmark example for gpt2-large (36-layer, 1280-hidden, 20-heads, 774M parameters).
|
||||
|
||||
Commands to run:
|
||||
python3 examples/benchmarks/pytorch_gpt2_large.py (Single GPU)
|
||||
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_gpt2_large.py \
|
||||
--distributed (Distributed)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
|
||||
from superbench.common.utils import logger
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Specify the model name and benchmark parameters.
|
||||
model_name = 'gpt2-large'
|
||||
parameters = '--batch_size 1 --duration 120 --seq_len 128 --precision float32 --run_count 2'
|
||||
if args.distributed:
|
||||
parameters += ' --distributed_impl ddp --distributed_backend nccl'
|
||||
|
||||
# Create context for gpt2-large benchmark and run it for 120 * 2 seconds.
|
||||
context = BenchmarkRegistry.create_benchmark_context(
|
||||
model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
|
||||
)
|
||||
|
||||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
if benchmark:
|
||||
logger.info(
|
||||
'benchmark: {}, return code: {}, result: {}'.format(
|
||||
benchmark.name, benchmark.return_code, benchmark.result
|
||||
)
|
||||
)
|
|
@ -5,5 +5,6 @@
|
|||
|
||||
from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
|
||||
|
||||
__all__ = ['ModelBenchmark', 'PytorchBERT']
|
||||
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2']
|
||||
|
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Module of the Pytorch GPT2 model."""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import GPT2Model, GPT2Config
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Precision
|
||||
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
|
||||
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
|
||||
|
||||
|
||||
class GPT2BenchmarkModel(torch.nn.Module):
|
||||
"""The GPT2 model for benchmarking."""
|
||||
def __init__(self, config, num_class):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
config (GPT2Config): Configurations of GPT2 model.
|
||||
num_class (int): The number of objects for classification.
|
||||
"""
|
||||
super().__init__()
|
||||
self._bert = GPT2Model(config)
|
||||
self._linear = torch.nn.Linear(config.hidden_size, num_class)
|
||||
|
||||
def forward(self, input):
|
||||
"""Forward propagation function.
|
||||
|
||||
Args:
|
||||
input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
|
||||
shape (batch_size, sequence_length).
|
||||
|
||||
Return:
|
||||
result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
|
||||
(classification token) further processed by a Linear layer, shape (batch_size, hidden_size).
|
||||
"""
|
||||
outputs = self._bert(input)
|
||||
result = self._linear(outputs[0])
|
||||
return result
|
||||
|
||||
|
||||
class PytorchGPT2(PytorchBase):
|
||||
"""The GPT2 benchmark class."""
|
||||
def __init__(self, name, parameters=''):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
name (str): benchmark name.
|
||||
parameters (str): benchmark parameters.
|
||||
"""
|
||||
super().__init__(name, parameters)
|
||||
self._config = None
|
||||
self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
|
||||
self._optimizer_type = Optimizer.ADAMW
|
||||
self._loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def add_parser_arguments(self):
|
||||
"""Add the GPT2-specified arguments.
|
||||
|
||||
GPT2 model reference: https://huggingface.co/transformers/model_doc/gpt2.html
|
||||
"""
|
||||
super().add_parser_arguments()
|
||||
|
||||
self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
|
||||
self._parser.add_argument('--hidden_size', type=int, default=1280, required=False, help='Hidden size.')
|
||||
self._parser.add_argument(
|
||||
'--num_hidden_layers', type=int, default=36, required=False, help='The number of hidden layers.'
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--num_attention_heads', type=int, default=20, required=False, help='The number of attention heads.'
|
||||
)
|
||||
self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
|
||||
|
||||
def _generate_dataset(self):
|
||||
"""Generate dataset for benchmarking according to shape info.
|
||||
|
||||
Return:
|
||||
True if dataset is created successfully.
|
||||
"""
|
||||
self._dataset = TorchRandomDataset(
|
||||
[self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
|
||||
)
|
||||
if len(self._dataset) == 0:
|
||||
logger.error('Generate random dataset failed - model: {}'.format(self._name))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _create_model(self, precision):
|
||||
"""Construct the model for benchmarking.
|
||||
|
||||
Args:
|
||||
precision (Precision): precision of model and input data, such as float32, float16.
|
||||
"""
|
||||
self._config = GPT2Config(
|
||||
n_embd=self._args.hidden_size, n_layer=self._args.num_hidden_layers, n_head=self._args.num_attention_heads
|
||||
)
|
||||
|
||||
try:
|
||||
self._model = GPT2BenchmarkModel(self._config, self._args.num_classes)
|
||||
self._model = self._model.to(dtype=getattr(torch, precision.value))
|
||||
if self._gpu_available:
|
||||
self._model = self._model.cuda()
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
|
||||
self._name, precision, str(e)
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
|
||||
if self._gpu_available:
|
||||
self._target = self._target.cuda()
|
||||
|
||||
return True
|
||||
|
||||
def _train_step(self, precision):
|
||||
"""Define the training process.
|
||||
|
||||
Args:
|
||||
precision (Precision): precision of model and input data, such as float32, float16.
|
||||
|
||||
Return:
|
||||
The step-time list of every training step.
|
||||
"""
|
||||
duration = []
|
||||
curr_step = 0
|
||||
while True:
|
||||
for idx, sample in enumerate(self._dataloader):
|
||||
start = time.time()
|
||||
if self._gpu_available:
|
||||
sample = sample.cuda()
|
||||
self._optimizer.zero_grad()
|
||||
output = self._model(sample)
|
||||
loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
|
||||
loss.backward()
|
||||
self._optimizer.step()
|
||||
end = time.time()
|
||||
curr_step += 1
|
||||
if curr_step > self._args.num_warmup:
|
||||
# Save the step time of every training/inference step, unit is millisecond.
|
||||
duration.append((end - start) * 1000)
|
||||
if self._is_finished(curr_step, end):
|
||||
return duration
|
||||
|
||||
def _inference_step(self, precision):
|
||||
"""Define the inference process.
|
||||
|
||||
Args:
|
||||
precision (Precision): precision of model and input data,
|
||||
such as float32, float16.
|
||||
|
||||
Return:
|
||||
The latency list of every inference operation.
|
||||
"""
|
||||
duration = []
|
||||
curr_step = 0
|
||||
with torch.no_grad():
|
||||
self._model.eval()
|
||||
while True:
|
||||
for idx, sample in enumerate(self._dataloader):
|
||||
start = time.time()
|
||||
if self._gpu_available:
|
||||
sample = sample.cuda()
|
||||
self._model(sample)
|
||||
if self._gpu_available:
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
curr_step += 1
|
||||
if curr_step > self._args.num_warmup:
|
||||
# Save the step time of every training/inference step, unit is millisecond.
|
||||
duration.append((end - start) * 1000)
|
||||
if self._is_finished(curr_step, end):
|
||||
return duration
|
||||
|
||||
|
||||
# Register GPT2 benchmark with 117M parameters.
|
||||
# Reference: https://huggingface.co/transformers/pretrained_models.html
|
||||
BenchmarkRegistry.register_benchmark(
|
||||
'pytorch-gpt2-small', PytorchGPT2, parameters='--hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12'
|
||||
)
|
||||
|
||||
# Register GPT2 benchmark with 345M parameters.
|
||||
# Reference: https://huggingface.co/transformers/pretrained_models.html
|
||||
BenchmarkRegistry.register_benchmark(
|
||||
'pytorch-gpt2-medium', PytorchGPT2, parameters='--hidden_size=1024 --num_hidden_layers=24 --num_attention_heads=16'
|
||||
)
|
||||
|
||||
# Register GPT2 benchmark with 774M parameters.
|
||||
# Reference: https://huggingface.co/transformers/pretrained_models.html
|
||||
BenchmarkRegistry.register_benchmark(
|
||||
'pytorch-gpt2-large', PytorchGPT2, parameters='--hidden_size=1280 --num_hidden_layers=36 --num_attention_heads=20'
|
||||
)
|
||||
|
||||
# Register GPT2 benchmark with 1558M parameters.
|
||||
# Reference: https://huggingface.co/transformers/pretrained_models.html
|
||||
BenchmarkRegistry.register_benchmark(
|
||||
'pytorch-gpt2-xl', PytorchGPT2, parameters='--hidden_size=1600 --num_hidden_layers=48 --num_attention_heads=25'
|
||||
)
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tests for GPT2 model benchmarks."""
|
||||
|
||||
from tests.helper import decorator
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
|
||||
|
||||
|
||||
@decorator.cuda_test
|
||||
@decorator.pytorch_test
|
||||
def test_pytorch_gpt2_small():
|
||||
"""Test pytorch-gpt2-small benchmark."""
|
||||
context = BenchmarkRegistry.create_benchmark_context(
|
||||
'gpt2-small',
|
||||
platform=Platform.CUDA,
|
||||
parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
|
||||
--model_action train inference',
|
||||
framework=Framework.PYTORCH
|
||||
)
|
||||
|
||||
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
|
||||
|
||||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
|
||||
# Check basic information.
|
||||
assert (benchmark)
|
||||
assert (isinstance(benchmark, PytorchGPT2))
|
||||
assert (benchmark.name == 'pytorch-gpt2-small')
|
||||
assert (benchmark.type == BenchmarkType.MODEL)
|
||||
|
||||
# Check predefined parameters of gpt2-large model.
|
||||
assert (benchmark._args.hidden_size == 768)
|
||||
assert (benchmark._args.num_hidden_layers == 12)
|
||||
assert (benchmark._args.num_attention_heads == 12)
|
||||
|
||||
# Check parameters specified in BenchmarkContext.
|
||||
assert (benchmark._args.batch_size == 1)
|
||||
assert (benchmark._args.num_classes == 5)
|
||||
assert (benchmark._args.seq_len == 8)
|
||||
assert (benchmark._args.num_warmup == 2)
|
||||
assert (benchmark._args.num_steps == 4)
|
||||
|
||||
# Test Dataset.
|
||||
assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
|
||||
|
||||
# Check results and metrics.
|
||||
assert (benchmark.run_count == 1)
|
||||
assert (benchmark.return_code == ReturnCode.SUCCESS)
|
||||
for metric in [
|
||||
'steptime_train_float32', 'throughput_train_float32', 'steptime_train_float16', 'throughput_train_float16',
|
||||
'steptime_inference_float32', 'throughput_inference_float32', 'steptime_inference_float16',
|
||||
'throughput_inference_float16'
|
||||
]:
|
||||
assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
|
||||
assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
|
||||
assert (len(benchmark.result[metric]) == benchmark.run_count)
|
Загрузка…
Ссылка в новой задаче