Benchmarks - Support FP8 in BERT models (#446)

Support FP8 in PyTorch BERT models:

* add fp8 hybrid/e4m3/e5m2 to the precision arguments
* build BERT encoders with `te.TransformerLayer` to replace `transformers.BertModel`
* wrap forward steps with fp8 autocast (a usage sketch follows below)
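
For reference, here is a minimal sketch of launching this benchmark with one of the new precisions through the benchmark registry; the model name `'bert-large'` and the parameter string are illustrative assumptions, not part of this change:

```python
# Minimal sketch (assumed model name and parameters): launch the PyTorch
# BERT benchmark with the new fp8_hybrid precision.
from superbench.benchmarks import BenchmarkRegistry, Framework, Platform

context = BenchmarkRegistry.create_benchmark_context(
    'bert-large',
    platform=Platform.CUDA,
    parameters='--batch_size 1 --precision fp8_hybrid --num_steps 8',
    framework=Framework.PYTORCH,
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
```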
Yifan Xiong 2023-01-04 11:12:05 +08:00, committed by GitHub
Parent 65e433c0c6
Commit 5197cdf5cb
3 changed files with 105 additions and 9 deletions


@@ -44,6 +44,9 @@ class BenchmarkType(Enum):
 class Precision(Enum):
     """The Enum class representing different data precisions."""
+    FP8_HYBRID = 'fp8_hybrid'
+    FP8_E4M3 = 'fp8_e4m3'
+    FP8_E5M2 = 'fp8_e5m2'
     FLOAT16 = 'float16'
     FLOAT32 = 'float32'
     FLOAT64 = 'float64'
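
For context: `fp8_e4m3` and `fp8_e5m2` pick a single FP8 format (E4M3 spends 4 bits on exponent and 3 on mantissa, favoring precision; E5M2 spends 5 on exponent and 2 on mantissa, favoring dynamic range), while `fp8_hybrid` combines them. A sketch of how these names line up with Transformer Engine's `Format` enum:

```python
# Mapping of the new Precision values onto Transformer Engine's FP8 formats
# (transformer_engine.common.recipe.Format):
#   fp8_e4m3   -> Format.E4M3    4 exponent bits, 3 mantissa bits (precision)
#   fp8_e5m2   -> Format.E5M2    5 exponent bits, 2 mantissa bits (range)
#   fp8_hybrid -> Format.HYBRID  E4M3 for forward tensors, E5M2 for gradients
from transformer_engine.common.recipe import Format

fmt = Format['HYBRID']  # the same name-based lookup the benchmark performs
```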


@@ -5,6 +5,11 @@
 import torch
 from transformers import BertModel, BertConfig
+try:
+    import transformer_engine.pytorch as te
+    from transformer_engine.common.recipe import Format, DelayedScaling
+except ImportError:
+    te = None
 from superbench.common.utils import logger
 from superbench.benchmarks import BenchmarkRegistry, Precision
@@ -42,6 +47,55 @@ class BertBenchmarkModel(torch.nn.Module):
         return result


+class TeBertBenchmarkModel(torch.nn.Module):
+    """BERT model using Transformer Engine."""
+    def __init__(self, config, num_classes):
+        """Constructor.
+
+        Args:
+            config (BertConfig): Configurations of BERT model.
+            num_classes (int): The number of objects for classification.
+        """
+        super().__init__()
+        self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
+        # Build BERT encoder layers with te.TransformerLayer instead of
+        # nn.TransformerEncoderLayer, expected input shape: (seq_len, batch_size, hidden_size).
+        # Construct a fresh layer per iteration so weights are not shared across layers.
+        self._encoder_layers = torch.nn.ModuleList(
+            [
+                te.TransformerLayer(
+                    config.hidden_size,
+                    config.intermediate_size,
+                    config.num_attention_heads,
+                    apply_residual_connection_post_layernorm=True,
+                    output_layernorm=True,
+                    layer_type='encoder',
+                ) for _ in range(config.num_hidden_layers)
+            ]
+        )
+        # BertPooler used in huggingface transformers
+        # https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
+        self._pooler = torch.nn.Sequential(
+            torch.nn.Linear(config.hidden_size, config.hidden_size),
+            torch.nn.Tanh(),
+        )
+        self._linear = torch.nn.Linear(config.hidden_size, num_classes)
+
+    def forward(self, input):
+        """Forward propagation function.
+
+        Args:
+            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
+                shape (batch_size, sequence_length).
+
+        Return:
+            out (torch.FloatTensor): Classification logits computed from the last hidden
+                state of the first token (classification token) after the pooler and a
+                Linear layer, shape (batch_size, num_classes).
+        """
+        # (batch_size, seq_len) -> (seq_len, batch_size) -> (seq_len, batch_size, hidden_size)
+        out = self._embedding(input.movedim(0, -1))
+        for layer in self._encoder_layers:
+            out = layer(out, attention_mask=None)
+        # Move batch back to dim 0, pool the first ([CLS]) token, then classify.
+        out = self._linear(self._pooler(out.movedim(0, 1)[:, 0]))
+        return out
+
+
 class PytorchBERT(PytorchBase):
     """The BERT benchmark class."""
     def __init__(self, name, parameters=''):
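
Pulling the pieces together, a minimal standalone sketch of driving `TeBertBenchmarkModel` in FP8, mirroring what the benchmark does internally; the config sizes and class count are illustrative, and FP8 GEMMs need dimensions divisible by 16:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from transformers import BertConfig

config = BertConfig(num_hidden_layers=2)  # small config for illustration
model = TeBertBenchmarkModel(config, num_classes=100).to(dtype=torch.float16).cuda()
tokens = torch.randint(0, config.vocab_size, (8, 128)).cuda()  # (batch_size, seq_len)

recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 forward, E5M2 for gradients
    amax_history_len=16,       # scales derived from a 16-step amax window
    amax_compute_algo='max',   # reduce the window with max()
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    logits = model(tokens)     # (batch_size, num_classes)
```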
@@ -53,7 +107,14 @@ class PytorchBERT(PytorchBase):
         """
         super().__init__(name, parameters)
         self._config = None
-        self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
+        self._fp8_recipe = None
+        self._supported_precision = [
+            Precision.FLOAT32,
+            Precision.FLOAT16,
+            Precision.FP8_HYBRID,
+            Precision.FP8_E4M3,
+            Precision.FP8_E5M2,
+        ]
         self._optimizer_type = Optimizer.ADAMW
         self._loss_fn = torch.nn.CrossEntropyLoss()
@@ -105,9 +166,31 @@ class PytorchBERT(PytorchBase):
             intermediate_size=self._args.intermediate_size
         )

+        enable_fp8 = precision.name.startswith('FP8_')
+        if enable_fp8 and te is None:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: Cannot find transformer_engine.'
+            )
+            return False
+        if enable_fp8 and not self._gpu_available:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: FP8 is only supported on GPU.'
+            )
+            return False
+
         try:
-            self._model = BertBenchmarkModel(self._config, self._args.num_classes)
-            self._model = self._model.to(dtype=getattr(torch, precision.value))
+            if enable_fp8:
+                # Map the Precision name to the te recipe format,
+                # e.g. 'FP8_HYBRID' -> Format.HYBRID.
+                self._fp8_recipe = DelayedScaling(
+                    fp8_format=Format[precision.name[len('FP8_'):]],
+                    amax_history_len=16,
+                    amax_compute_algo='max',
+                )
+                self._model = TeBertBenchmarkModel(self._config, self._args.num_classes).to(dtype=torch.float16)
+            else:
+                self._model = BertBenchmarkModel(self._config, self._args.num_classes)
+                self._model = self._model.to(dtype=getattr(torch, precision.value))
             if self._gpu_available:
                 self._model = self._model.cuda()
         except BaseException as e:
@@ -142,7 +225,11 @@ class PytorchBERT(PytorchBase):
             if self._gpu_available:
                 sample = sample.cuda()
             self._optimizer.zero_grad()
-            output = self._model(sample)
+            if self._fp8_recipe is not None:
+                with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                    output = self._model(sample)
+            else:
+                output = self._model(sample)
             loss = self._loss_fn(output, self._target)
             loss.backward()
             self._optimizer.step()
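
Note the scoping: only the forward pass runs inside `te.fp8_autocast`, while the loss computation and `backward()` stay outside, which matches Transformer Engine's recommended usage (FP8 handling in the backward pass is managed by TE's autograd functions). The pattern, condensed with generic names:

```python
# Wrap only the forward pass; keep loss and backward outside the region.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    output = model(sample)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
```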
@@ -174,7 +261,11 @@ class PytorchBERT(PytorchBase):
             start = self._timer()
             if self._gpu_available:
                 sample = sample.cuda()
-            self._model(sample)
+            if self._fp8_recipe is not None:
+                with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                    self._model(sample)
+            else:
+                self._model(sample)
             end = self._timer()
             curr_step += 1
             if curr_step > self._args.num_warmup:


@@ -169,8 +169,9 @@ def test_arguments_related_interfaces():
   --num_warmup int      The number of warmup step.
   --pin_memory          Enable option to pin memory in data loader.
   --precision Precision [Precision ...]
-                        Model precision. E.g. float16 float32 float64 bfloat16
-                        uint8 int8 int16 int32 int64.
+                        Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2
+                        float16 float32 float64 bfloat16 uint8 int8 int16
+                        int32 int64.
   --run_count int       The run count of benchmark.
   --sample_count int    The number of data samples in dataset.
   --seq_len int         Sequence length."""
@@ -207,8 +208,9 @@ def test_preprocess():
   --num_warmup int      The number of warmup step.
   --pin_memory          Enable option to pin memory in data loader.
   --precision Precision [Precision ...]
-                        Model precision. E.g. float16 float32 float64 bfloat16
-                        uint8 int8 int16 int32 int64.
+                        Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2
+                        float16 float32 float64 bfloat16 uint8 int8 int16
+                        int32 int64.
   --run_count int       The run count of benchmark.
   --sample_count int    The number of data samples in dataset.
   --seq_len int         Sequence length."""