Benchmarks: Add Benchmark - Add LSTM model benchmarks. (#60)

* Benchmarks: Add Benchmark - Add LSTM model benchmarks.
This commit is contained in:
guoshzhao 2021-04-20 10:53:44 +08:00 коммит произвёл GitHub
Родитель 902ea211d1
Коммит 2a7ab691f1
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 321 добавлений и 7 удалений

Просмотреть файл

@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Model benchmark example for lstm (8-layer, 1024-hidden, 256-input_size, False-bidirectional).
Commands to run:
python3 examples/benchmarks/pytorch_lstm.py (Single GPU)
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_lstm.py \
--distributed (Distributed)
"""
import argparse
from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
from superbench.common.utils import logger
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
)
args = parser.parse_args()
# Specify the model name and benchmark parameters.
model_name = 'lstm'
parameters = '--batch_size 1 --seq_len 256 --precision float32 --num_warmup 8 --num_steps 64 --run_count 2'
if args.distributed:
parameters += ' --distributed_impl ddp --distributed_backend nccl'
# Create context for lstm benchmark and run it for 64 steps.
context = BenchmarkRegistry.create_benchmark_context(
model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)

Просмотреть файл

@ -7,5 +7,6 @@ from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN'] __all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM']

Просмотреть файл

@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
class BertBenchmarkModel(torch.nn.Module): class BertBenchmarkModel(torch.nn.Module):
"""The BERT model for benchmarking.""" """The BERT model for benchmarking."""
def __init__(self, config, num_class): def __init__(self, config, num_classes):
"""Constructor. """Constructor.
Args: Args:
config (BertConfig): Configurations of BERT model. config (BertConfig): Configurations of BERT model.
num_class (int): The number of objects for classification. num_classes (int): The number of objects for classification.
""" """
super().__init__() super().__init__()
self._bert = BertModel(config) self._bert = BertModel(config)
self._linear = torch.nn.Linear(config.hidden_size, num_class) self._linear = torch.nn.Linear(config.hidden_size, num_classes)
def forward(self, input): def forward(self, input):
"""Forward propagation function. """Forward propagation function.

Просмотреть файл

@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
class GPT2BenchmarkModel(torch.nn.Module): class GPT2BenchmarkModel(torch.nn.Module):
"""The GPT2 model for benchmarking.""" """The GPT2 model for benchmarking."""
def __init__(self, config, num_class): def __init__(self, config, num_classes):
"""Constructor. """Constructor.
Args: Args:
config (GPT2Config): Configurations of GPT2 model. config (GPT2Config): Configurations of GPT2 model.
num_class (int): The number of objects for classification. num_classes (int): The number of objects for classification.
""" """
super().__init__() super().__init__()
self._bert = GPT2Model(config) self._bert = GPT2Model(config)
self._linear = torch.nn.Linear(config.hidden_size, num_class) self._linear = torch.nn.Linear(config.hidden_size, num_classes)
def forward(self, input): def forward(self, input):
"""Forward propagation function. """Forward propagation function.

Просмотреть файл

@ -0,0 +1,196 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the Pytorch LSTM model."""
import time
import torch
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
class LSTMBenchmarkModel(torch.nn.Module):
"""The LSTM model for benchmarking."""
def __init__(self, input_size, hidden_size, num_layers, bidirectional, num_classes):
"""Constructor.
Args:
input_size (int): The number of expected features in the input.
hidden_size (int): The number of features in the hidden state.
num_layers (int): The number of recurrent layers.
bidirectional (bool): If True, becomes a bidirectional LSTM.
num_classes (int): The number of objects for classification.
"""
super().__init__()
self._lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
self._linear = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input):
"""Forward propagation function.
Args:
input (torch.FloatTensor): Tensor containing the features of the input sequence,
shape (sequence_length, batch_size, input_size).
Return:
result (torch.FloatTensor): The output features from the last layer of the LSTM
further processed by a Linear layer, shape (batch_size, num_classes).
"""
self._lstm.flatten_parameters()
outputs = self._lstm(input)
result = self._linear(outputs[0][:, -1, :])
return result
class PytorchLSTM(PytorchBase):
"""The LSTM benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.
Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super().__init__(name, parameters)
self._config = None
self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
self._optimizer_type = Optimizer.SGD
self._loss_fn = torch.nn.CrossEntropyLoss()
def add_parser_arguments(self):
"""Add the LSTM-specified arguments.
LSTM model reference: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
"""
super().add_parser_arguments()
self._parser.add_argument(
'--num_classes', type=int, default=100, required=False, help='The number of objects for classification.'
)
self._parser.add_argument(
'--input_size', type=int, default=256, required=False, help='The number of expected features in the input.'
)
self._parser.add_argument(
'--hidden_size', type=int, default=1024, required=False, help='The number of features in the hidden state.'
)
self._parser.add_argument(
'--num_layers', type=int, default=8, required=False, help='The number of recurrent layers.'
)
self._parser.add_argument('--bidirectional', action='store_true', default=False, help='Bidirectional LSTM.')
self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
def _generate_dataset(self):
"""Generate dataset for benchmarking according to shape info.
Return:
True if dataset is created successfully.
"""
self._dataset = TorchRandomDataset(
[self._args.sample_count, self._args.seq_len, self._args.input_size], self._world_size, dtype=torch.float
)
if len(self._dataset) == 0:
logger.error('Generate random dataset failed - model: {}'.format(self._name))
return False
return True
def _create_model(self, precision):
"""Construct the model for benchmarking.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
"""
try:
self._model = LSTMBenchmarkModel(
self._args.input_size, self._args.hidden_size, self._args.num_layers, self._args.bidirectional,
self._args.num_classes
)
self._model = self._model.to(dtype=getattr(torch, precision.value))
if self._gpu_available:
self._model = self._model.cuda()
except BaseException as e:
logger.error(
'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
self._name, precision, str(e)
)
)
return False
self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
if self._gpu_available:
self._target = self._target.cuda()
return True
def _train_step(self, precision):
"""Define the training process.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
Return:
The step-time list of every training step.
"""
duration = []
curr_step = 0
while True:
for idx, sample in enumerate(self._dataloader):
start = time.time()
sample = sample.to(dtype=getattr(torch, precision.value))
if self._gpu_available:
sample = sample.cuda()
self._optimizer.zero_grad()
output = self._model(sample)
loss = self._loss_fn(output, self._target)
loss.backward()
self._optimizer.step()
end = time.time()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000)
if self._is_finished(curr_step, end):
return duration
def _inference_step(self, precision):
"""Define the inference process.
Args:
precision (Precision): precision of model and input data,
such as float32, float16.
Return:
The latency list of every inference operation.
"""
duration = []
curr_step = 0
with torch.no_grad():
self._model.eval()
while True:
for idx, sample in enumerate(self._dataloader):
start = time.time()
sample = sample.to(dtype=getattr(torch, precision.value))
if self._gpu_available:
sample = sample.cuda()
self._model(sample)
if self._gpu_available:
torch.cuda.synchronize()
end = time.time()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000)
if self._is_finished(curr_step, end):
return duration
# Register LSTM benchmark.
BenchmarkRegistry.register_benchmark(
'pytorch-lstm', PytorchLSTM, parameters='--input_size=256 --hidden_size=1024 --num_layers=8'
)

Просмотреть файл

@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for LSTM model benchmarks."""
from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_lstm_with_gpu():
"""Test pytorch-lstm benchmark with GPU."""
run_pytorch_lstm(
parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
--model_action train inference',
check_metrics=[
'steptime_train_float32', 'throughput_train_float32', 'steptime_train_float16', 'throughput_train_float16',
'steptime_inference_float32', 'throughput_inference_float32', 'steptime_inference_float16',
'throughput_inference_float16'
]
)
@decorator.pytorch_test
def test_pytorch_lstm_no_gpu():
"""Test pytorch-lstm benchmark with CPU."""
run_pytorch_lstm(
parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
--model_action train inference --precision float32 --no_gpu',
check_metrics=[
'steptime_train_float32', 'throughput_train_float32', 'steptime_inference_float32',
'throughput_inference_float32'
]
)
def run_pytorch_lstm(parameters='', check_metrics=[]):
"""Test pytorch-lstm benchmark."""
context = BenchmarkRegistry.create_benchmark_context(
'lstm', platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
)
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information.
assert (benchmark)
assert (isinstance(benchmark, PytorchLSTM))
assert (benchmark.name == 'pytorch-lstm')
assert (benchmark.type == BenchmarkType.MODEL)
# Check predefined parameters of lstm model.
assert (benchmark._args.input_size == 256)
assert (benchmark._args.hidden_size == 1024)
assert (benchmark._args.num_layers == 8)
# Check parameters specified in BenchmarkContext.
assert (benchmark._args.batch_size == 1)
assert (benchmark._args.num_classes == 5)
assert (benchmark._args.seq_len == 8)
assert (benchmark._args.num_warmup == 2)
assert (benchmark._args.num_steps == 4)
# Check dataset scale.
assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
# Check results and metrics.
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)
for metric in check_metrics:
assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
assert (len(benchmark.result[metric]) == benchmark.run_count)