Benchmarks - Support FP8 in BERT models (#446)
Support FP8 in PyTorch BERT models:
* add fp8 hybrid/e4m3/e5m2 in precision arguments
* build BERT encoders with `te.TransformerLayer` to replace `transformers.BertModel`
* wrap forward steps with fp8 autocast (see the sketch below)
Parent: 65e433c0c6
Commit: 5197cdf5cb
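For orientation, a minimal sketch of the pattern this commit wires into the benchmark: build an encoder from `te.TransformerLayer`, create a `DelayedScaling` FP8 recipe, and run the forward pass inside `te.fp8_autocast`. This is an illustration only, assuming Transformer Engine and a CUDA GPU are available; the layer sizes and shapes below are placeholders, not the benchmark's BERT configuration.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling

    # Toy sizes for illustration only.
    hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
    layer = te.TransformerLayer(
        hidden_size, ffn_hidden_size, num_heads, layer_type='encoder'
    ).to(dtype=torch.float16).cuda()

    # Delayed-scaling FP8 recipe: HYBRID uses E4M3 for forward tensors and E5M2 for gradients.
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo='max')

    # (seq_len, batch_size, hidden_size) is the default layout expected by te.TransformerLayer.
    x = torch.randn(128, 4, hidden_size, dtype=torch.float16, device='cuda', requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x, attention_mask=None)
    y.float().sum().backward()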
@@ -44,6 +44,9 @@ class BenchmarkType(Enum):
 
 class Precision(Enum):
     """The Enum class representing different data precisions."""
+    FP8_HYBRID = 'fp8_hybrid'
+    FP8_E4M3 = 'fp8_e4m3'
+    FP8_E5M2 = 'fp8_e5m2'
     FLOAT16 = 'float16'
     FLOAT32 = 'float32'
     FLOAT64 = 'float64'
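For reference (an addition, not part of the diff): `fp8_e4m3` favors precision, `fp8_e5m2` favors dynamic range, and `fp8_hybrid` uses E4M3 for forward tensors and E5M2 for gradients. The values line up with Transformer Engine's `Format` enum; a sketch with a hypothetical helper name:

    from transformer_engine.common.recipe import Format

    def te_fp8_format(precision_value: str) -> Format:
        # Hypothetical helper: 'fp8_hybrid' -> Format.HYBRID, 'fp8_e4m3' -> Format.E4M3, 'fp8_e5m2' -> Format.E5M2.
        return Format[precision_value.upper()[len('FP8_'):]]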
@@ -5,6 +5,11 @@
 
 import torch
 from transformers import BertModel, BertConfig
+try:
+    import transformer_engine.pytorch as te
+    from transformer_engine.common.recipe import Format, DelayedScaling
+except ImportError:
+    te = None
 
 from superbench.common.utils import logger
 from superbench.benchmarks import BenchmarkRegistry, Precision
@@ -42,6 +47,55 @@ class BertBenchmarkModel(torch.nn.Module):
         return result
 
 
+class TeBertBenchmarkModel(torch.nn.Module):
+    """BERT model using Transformer Engine."""
+    def __init__(self, config, num_classes):
+        """Constructor.
+
+        Args:
+            config (BertConfig): Configurations of BERT model.
+            num_classes (int): The number of objects for classification.
+        """
+        super().__init__()
+
+        self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
+        # Build BERT using nn.TransformerEncoderLayer or te.TransformerLayer
+        # input shape: (seq_len, batch_size, hidden_size)
+        encoder_layer = te.TransformerLayer(
+            config.hidden_size,
+            config.intermediate_size,
+            config.num_attention_heads,
+            apply_residual_connection_post_layernorm=True,
+            output_layernorm=True,
+            layer_type='encoder',
+        )
+        self._encoder_layers = torch.nn.ModuleList([encoder_layer for _ in range(config.num_hidden_layers)])
+        # BertPooler used in huggingface transformers
+        # https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
+        self._pooler = torch.nn.Sequential(
+            torch.nn.Linear(config.hidden_size, config.hidden_size),
+            torch.nn.Tanh(),
+        )
+        self._linear = torch.nn.Linear(config.hidden_size, num_classes)
+
+    def forward(self, input):
+        """Forward propagation function.
+
+        Args:
+            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
+              shape (batch_size, sequence_length).
+
+        Return:
+            out (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
+              (classification token) further processed by a Linear layer, shape (batch_size, hidden_size).
+        """
+        out = self._embedding(input.movedim(0, -1))
+        for layer in self._encoder_layers:
+            out = layer(out, attention_mask=None)
+        out = self._linear(self._pooler(out.movedim(0, 1)[:, 0]))
+        return out
+
+
 class PytorchBERT(PytorchBase):
     """The BERT benchmark class."""
     def __init__(self, name, parameters=''):
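A minimal usage sketch for the new class (not part of the diff), assuming `TeBertBenchmarkModel` is in scope and Transformer Engine plus a CUDA GPU are available; the configuration values are placeholders:

    import torch
    from transformers import BertConfig

    config = BertConfig(hidden_size=1024, num_hidden_layers=2, num_attention_heads=16, intermediate_size=4096)
    model = TeBertBenchmarkModel(config, num_classes=100).to(dtype=torch.float16).cuda()
    tokens = torch.randint(0, config.vocab_size, (8, 128), device='cuda')  # (batch_size, seq_len)
    logits = model(tokens)  # classification output, one row per sample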
@@ -53,7 +107,14 @@ class PytorchBERT(PytorchBase):
         """
         super().__init__(name, parameters)
         self._config = None
-        self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
+        self._fp8_recipe = None
+        self._supported_precision = [
+            Precision.FLOAT32,
+            Precision.FLOAT16,
+            Precision.FP8_HYBRID,
+            Precision.FP8_E4M3,
+            Precision.FP8_E5M2,
+        ]
         self._optimizer_type = Optimizer.ADAMW
         self._loss_fn = torch.nn.CrossEntropyLoss()
 
@@ -105,9 +166,31 @@ class PytorchBERT(PytorchBase):
             intermediate_size=self._args.intermediate_size
         )
 
+        enable_fp8 = precision.name.startswith('FP8_')
+        if enable_fp8 and te is None:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: Cannot find transformer_engine.'
+            )
+            return False
+        if enable_fp8 and not self._gpu_available:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: FP8 is only supported on GPU.'
+            )
+            return False
+
         try:
-            self._model = BertBenchmarkModel(self._config, self._args.num_classes)
-            self._model = self._model.to(dtype=getattr(torch, precision.value))
+            if enable_fp8:
+                self._fp8_recipe = DelayedScaling(
+                    fp8_format=Format[precision.name.strip('FP8_')],
+                    amax_history_len=16,
+                    amax_compute_algo='max',
+                )
+                self._model = TeBertBenchmarkModel(self._config, self._args.num_classes).to(dtype=torch.float16)
+            else:
+                self._model = BertBenchmarkModel(self._config, self._args.num_classes)
+                self._model = self._model.to(dtype=getattr(torch, precision.value))
             if self._gpu_available:
                 self._model = self._model.cuda()
         except BaseException as e:
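Two notes on the recipe built above (additions, not from the diff): `DelayedScaling` derives each tensor's FP8 scale from a rolling window of absolute-max statistics, here the maximum over the last 16 steps; and `str.strip('FP8_')` strips a character set from both ends rather than a prefix, which happens to yield 'HYBRID', 'E4M3' and 'E5M2' for these particular names. A minimal equivalent for the hybrid case:

    from transformer_engine.common.recipe import DelayedScaling, Format

    # Same recipe as '--precision fp8_hybrid' above: scales follow the max amax of the last 16 steps.
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo='max')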
@@ -142,7 +225,11 @@ class PytorchBERT(PytorchBase):
             if self._gpu_available:
                 sample = sample.cuda()
             self._optimizer.zero_grad()
-            output = self._model(sample)
+            if self._fp8_recipe is not None:
+                with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                    output = self._model(sample)
+            else:
+                output = self._model(sample)
             loss = self._loss_fn(output, self._target)
             loss.backward()
             self._optimizer.step()
@@ -174,7 +261,11 @@ class PytorchBERT(PytorchBase):
             start = self._timer()
             if self._gpu_available:
                 sample = sample.cuda()
-            self._model(sample)
+            if self._fp8_recipe is not None:
+                with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                    self._model(sample)
+            else:
+                self._model(sample)
             end = self._timer()
             curr_step += 1
             if curr_step > self._args.num_warmup:
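For context, a sketch of how the new precisions could be exercised end to end. It mirrors the pattern of SuperBench's example scripts; the benchmark name, registry calls, and parameters are assumptions, not taken from this diff:

    from superbench.benchmarks import BenchmarkRegistry, Framework, Platform

    # Hypothetical launch of the PyTorch BERT benchmark with the new FP8 hybrid precision.
    context = BenchmarkRegistry.create_benchmark_context(
        'bert-base',
        platform=Platform.CUDA,
        parameters='--batch_size 1 --seq_len 128 --precision fp8_hybrid',
        framework=Framework.PYTORCH,
    )
    benchmark = BenchmarkRegistry.launch_benchmark(context)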
@@ -169,8 +169,9 @@ def test_arguments_related_interfaces():
   --num_warmup int      The number of warmup step.
   --pin_memory          Enable option to pin memory in data loader.
   --precision Precision [Precision ...]
-                        Model precision. E.g. float16 float32 float64 bfloat16
-                        uint8 int8 int16 int32 int64.
+                        Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2
+                        float16 float32 float64 bfloat16 uint8 int8 int16
+                        int32 int64.
   --run_count int       The run count of benchmark.
   --sample_count int    The number of data samples in dataset.
   --seq_len int         Sequence length."""
@@ -207,8 +208,9 @@ def test_preprocess():
   --num_warmup int      The number of warmup step.
   --pin_memory          Enable option to pin memory in data loader.
   --precision Precision [Precision ...]
-                        Model precision. E.g. float16 float32 float64 bfloat16
-                        uint8 int8 int16 int32 int64.
+                        Model precision. E.g. fp8_hybrid fp8_e4m3 fp8_e5m2
+                        float16 float32 float64 bfloat16 uint8 int8 int16
+                        int32 int64.
   --run_count int       The run count of benchmark.
   --sample_count int    The number of data samples in dataset.
   --seq_len int         Sequence length."""