Benchmarks: Add Benchmark - Add LSTM model benchmarks. (#60)
* Benchmarks: Add Benchmark - Add LSTM model benchmarks.
This commit is contained in:
Родитель
902ea211d1
Коммит
2a7ab691f1
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Model benchmark example for lstm (8-layer, 1024-hidden, 256-input_size, False-bidirectional).
|
||||
|
||||
Commands to run:
|
||||
python3 examples/benchmarks/pytorch_lstm.py (Single GPU)
|
||||
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_lstm.py \
|
||||
--distributed (Distributed)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
|
||||
from superbench.common.utils import logger
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Specify the model name and benchmark parameters.
|
||||
model_name = 'lstm'
|
||||
parameters = '--batch_size 1 --seq_len 256 --precision float32 --num_warmup 8 --num_steps 64 --run_count 2'
|
||||
if args.distributed:
|
||||
parameters += ' --distributed_impl ddp --distributed_backend nccl'
|
||||
|
||||
# Create context for lstm benchmark and run it for 64 steps.
|
||||
context = BenchmarkRegistry.create_benchmark_context(
|
||||
model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
|
||||
)
|
||||
|
||||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
if benchmark:
|
||||
logger.info(
|
||||
'benchmark: {}, return code: {}, result: {}'.format(
|
||||
benchmark.name, benchmark.return_code, benchmark.result
|
||||
)
|
||||
)
|
|
@ -7,5 +7,6 @@ from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
|
|||
from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
|
||||
|
||||
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN']
|
||||
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM']
|
||||
|
|
|
@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
|
|||
|
||||
class BertBenchmarkModel(torch.nn.Module):
|
||||
"""The BERT model for benchmarking."""
|
||||
def __init__(self, config, num_class):
|
||||
def __init__(self, config, num_classes):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
config (BertConfig): Configurations of BERT model.
|
||||
num_class (int): The number of objects for classification.
|
||||
num_classes (int): The number of objects for classification.
|
||||
"""
|
||||
super().__init__()
|
||||
self._bert = BertModel(config)
|
||||
self._linear = torch.nn.Linear(config.hidden_size, num_class)
|
||||
self._linear = torch.nn.Linear(config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
"""Forward propagation function.
|
||||
|
|
|
@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
|
|||
|
||||
class GPT2BenchmarkModel(torch.nn.Module):
|
||||
"""The GPT2 model for benchmarking."""
|
||||
def __init__(self, config, num_class):
|
||||
def __init__(self, config, num_classes):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
config (GPT2Config): Configurations of GPT2 model.
|
||||
num_class (int): The number of objects for classification.
|
||||
num_classes (int): The number of objects for classification.
|
||||
"""
|
||||
super().__init__()
|
||||
self._bert = GPT2Model(config)
|
||||
self._linear = torch.nn.Linear(config.hidden_size, num_class)
|
||||
self._linear = torch.nn.Linear(config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
"""Forward propagation function.
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Module of the Pytorch LSTM model."""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from superbench.common.utils import logger
|
||||
from superbench.benchmarks import BenchmarkRegistry, Precision
|
||||
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
|
||||
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
|
||||
|
||||
|
||||
class LSTMBenchmarkModel(torch.nn.Module):
|
||||
"""The LSTM model for benchmarking."""
|
||||
def __init__(self, input_size, hidden_size, num_layers, bidirectional, num_classes):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
input_size (int): The number of expected features in the input.
|
||||
hidden_size (int): The number of features in the hidden state.
|
||||
num_layers (int): The number of recurrent layers.
|
||||
bidirectional (bool): If True, becomes a bidirectional LSTM.
|
||||
num_classes (int): The number of objects for classification.
|
||||
"""
|
||||
super().__init__()
|
||||
self._lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
|
||||
self._linear = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
"""Forward propagation function.
|
||||
|
||||
Args:
|
||||
input (torch.FloatTensor): Tensor containing the features of the input sequence,
|
||||
shape (sequence_length, batch_size, input_size).
|
||||
|
||||
Return:
|
||||
result (torch.FloatTensor): The output features from the last layer of the LSTM
|
||||
further processed by a Linear layer, shape (batch_size, num_classes).
|
||||
"""
|
||||
self._lstm.flatten_parameters()
|
||||
outputs = self._lstm(input)
|
||||
result = self._linear(outputs[0][:, -1, :])
|
||||
return result
|
||||
|
||||
|
||||
class PytorchLSTM(PytorchBase):
|
||||
"""The LSTM benchmark class."""
|
||||
def __init__(self, name, parameters=''):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
name (str): benchmark name.
|
||||
parameters (str): benchmark parameters.
|
||||
"""
|
||||
super().__init__(name, parameters)
|
||||
self._config = None
|
||||
self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
|
||||
self._optimizer_type = Optimizer.SGD
|
||||
self._loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def add_parser_arguments(self):
|
||||
"""Add the LSTM-specified arguments.
|
||||
|
||||
LSTM model reference: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
|
||||
"""
|
||||
super().add_parser_arguments()
|
||||
|
||||
self._parser.add_argument(
|
||||
'--num_classes', type=int, default=100, required=False, help='The number of objects for classification.'
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--input_size', type=int, default=256, required=False, help='The number of expected features in the input.'
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--hidden_size', type=int, default=1024, required=False, help='The number of features in the hidden state.'
|
||||
)
|
||||
self._parser.add_argument(
|
||||
'--num_layers', type=int, default=8, required=False, help='The number of recurrent layers.'
|
||||
)
|
||||
|
||||
self._parser.add_argument('--bidirectional', action='store_true', default=False, help='Bidirectional LSTM.')
|
||||
self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
|
||||
|
||||
def _generate_dataset(self):
|
||||
"""Generate dataset for benchmarking according to shape info.
|
||||
|
||||
Return:
|
||||
True if dataset is created successfully.
|
||||
"""
|
||||
self._dataset = TorchRandomDataset(
|
||||
[self._args.sample_count, self._args.seq_len, self._args.input_size], self._world_size, dtype=torch.float
|
||||
)
|
||||
if len(self._dataset) == 0:
|
||||
logger.error('Generate random dataset failed - model: {}'.format(self._name))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _create_model(self, precision):
|
||||
"""Construct the model for benchmarking.
|
||||
|
||||
Args:
|
||||
precision (Precision): precision of model and input data, such as float32, float16.
|
||||
"""
|
||||
try:
|
||||
self._model = LSTMBenchmarkModel(
|
||||
self._args.input_size, self._args.hidden_size, self._args.num_layers, self._args.bidirectional,
|
||||
self._args.num_classes
|
||||
)
|
||||
self._model = self._model.to(dtype=getattr(torch, precision.value))
|
||||
if self._gpu_available:
|
||||
self._model = self._model.cuda()
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
|
||||
self._name, precision, str(e)
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
|
||||
if self._gpu_available:
|
||||
self._target = self._target.cuda()
|
||||
|
||||
return True
|
||||
|
||||
def _train_step(self, precision):
|
||||
"""Define the training process.
|
||||
|
||||
Args:
|
||||
precision (Precision): precision of model and input data, such as float32, float16.
|
||||
|
||||
Return:
|
||||
The step-time list of every training step.
|
||||
"""
|
||||
duration = []
|
||||
curr_step = 0
|
||||
while True:
|
||||
for idx, sample in enumerate(self._dataloader):
|
||||
start = time.time()
|
||||
sample = sample.to(dtype=getattr(torch, precision.value))
|
||||
if self._gpu_available:
|
||||
sample = sample.cuda()
|
||||
self._optimizer.zero_grad()
|
||||
output = self._model(sample)
|
||||
loss = self._loss_fn(output, self._target)
|
||||
loss.backward()
|
||||
self._optimizer.step()
|
||||
end = time.time()
|
||||
curr_step += 1
|
||||
if curr_step > self._args.num_warmup:
|
||||
# Save the step time of every training/inference step, unit is millisecond.
|
||||
duration.append((end - start) * 1000)
|
||||
if self._is_finished(curr_step, end):
|
||||
return duration
|
||||
|
||||
def _inference_step(self, precision):
|
||||
"""Define the inference process.
|
||||
|
||||
Args:
|
||||
precision (Precision): precision of model and input data,
|
||||
such as float32, float16.
|
||||
|
||||
Return:
|
||||
The latency list of every inference operation.
|
||||
"""
|
||||
duration = []
|
||||
curr_step = 0
|
||||
with torch.no_grad():
|
||||
self._model.eval()
|
||||
while True:
|
||||
for idx, sample in enumerate(self._dataloader):
|
||||
start = time.time()
|
||||
sample = sample.to(dtype=getattr(torch, precision.value))
|
||||
if self._gpu_available:
|
||||
sample = sample.cuda()
|
||||
self._model(sample)
|
||||
if self._gpu_available:
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
curr_step += 1
|
||||
if curr_step > self._args.num_warmup:
|
||||
# Save the step time of every training/inference step, unit is millisecond.
|
||||
duration.append((end - start) * 1000)
|
||||
if self._is_finished(curr_step, end):
|
||||
return duration
|
||||
|
||||
|
||||
# Register LSTM benchmark.
|
||||
BenchmarkRegistry.register_benchmark(
|
||||
'pytorch-lstm', PytorchLSTM, parameters='--input_size=256 --hidden_size=1024 --num_layers=8'
|
||||
)
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tests for LSTM model benchmarks."""
|
||||
|
||||
from tests.helper import decorator
|
||||
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
|
||||
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
|
||||
|
||||
|
||||
@decorator.cuda_test
|
||||
@decorator.pytorch_test
|
||||
def test_pytorch_lstm_with_gpu():
|
||||
"""Test pytorch-lstm benchmark with GPU."""
|
||||
run_pytorch_lstm(
|
||||
parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
|
||||
--model_action train inference',
|
||||
check_metrics=[
|
||||
'steptime_train_float32', 'throughput_train_float32', 'steptime_train_float16', 'throughput_train_float16',
|
||||
'steptime_inference_float32', 'throughput_inference_float32', 'steptime_inference_float16',
|
||||
'throughput_inference_float16'
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@decorator.pytorch_test
|
||||
def test_pytorch_lstm_no_gpu():
|
||||
"""Test pytorch-lstm benchmark with CPU."""
|
||||
run_pytorch_lstm(
|
||||
parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
|
||||
--model_action train inference --precision float32 --no_gpu',
|
||||
check_metrics=[
|
||||
'steptime_train_float32', 'throughput_train_float32', 'steptime_inference_float32',
|
||||
'throughput_inference_float32'
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def run_pytorch_lstm(parameters='', check_metrics=[]):
|
||||
"""Test pytorch-lstm benchmark."""
|
||||
context = BenchmarkRegistry.create_benchmark_context(
|
||||
'lstm', platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
|
||||
)
|
||||
|
||||
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
|
||||
|
||||
benchmark = BenchmarkRegistry.launch_benchmark(context)
|
||||
|
||||
# Check basic information.
|
||||
assert (benchmark)
|
||||
assert (isinstance(benchmark, PytorchLSTM))
|
||||
assert (benchmark.name == 'pytorch-lstm')
|
||||
assert (benchmark.type == BenchmarkType.MODEL)
|
||||
|
||||
# Check predefined parameters of lstm model.
|
||||
assert (benchmark._args.input_size == 256)
|
||||
assert (benchmark._args.hidden_size == 1024)
|
||||
assert (benchmark._args.num_layers == 8)
|
||||
|
||||
# Check parameters specified in BenchmarkContext.
|
||||
assert (benchmark._args.batch_size == 1)
|
||||
assert (benchmark._args.num_classes == 5)
|
||||
assert (benchmark._args.seq_len == 8)
|
||||
assert (benchmark._args.num_warmup == 2)
|
||||
assert (benchmark._args.num_steps == 4)
|
||||
|
||||
# Check dataset scale.
|
||||
assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
|
||||
|
||||
# Check results and metrics.
|
||||
assert (benchmark.run_count == 1)
|
||||
assert (benchmark.return_code == ReturnCode.SUCCESS)
|
||||
for metric in check_metrics:
|
||||
assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
|
||||
assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
|
||||
assert (len(benchmark.result[metric]) == benchmark.run_count)
|
Загрузка…
Ссылка в новой задаче