Benchmarks: Add benchmark: Megatron-LM/Megatron-DeepSpeed GPT pretrain benchmark (#582)

**Description**
Add a GPT pretrain benchmark that runs with Megatron-LM or Megatron-DeepSpeed.
Yuting Jiang 2023-12-07 09:37:09 +08:00 committed by GitHub
Parent 254ea7feba
Commit dd5a6329ed
14 changed files with 2425 additions and 9 deletions

View file

@@ -41,6 +41,7 @@ RUN apt-get update && \
libtinfo5 \
libtool \
lshw \
python3-mpi4py \
net-tools \
openssh-client \
openssh-server \

View file

@@ -41,6 +41,7 @@ RUN apt-get update && \
libtinfo5 \
libtool \
lshw \
python3-mpi4py \
net-tools \
openssh-client \
openssh-server \

View file

@@ -41,6 +41,7 @@ RUN apt-get update && \
libtinfo5 \
libtool \
lshw \
python3-mpi4py \
net-tools \
numactl \
openssh-client \
@@ -136,7 +137,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
WORKDIR ${SB_HOME}
ADD third_party third_party
-RUN make -C third_party rocm -o rocm_hipblaslt
+RUN make -C third_party rocm -o rocm_hipblaslt -o megatron_deepspeed -o megatron_lm
ADD . .
RUN python3 -m pip install --upgrade setuptools==65.7 && \

View file

@@ -41,6 +41,7 @@ RUN apt-get update && \
libtinfo5 \
libtool \
lshw \
python3-mpi4py \
net-tools \
numactl \
openssh-client \
@@ -141,7 +142,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
WORKDIR ${SB_HOME}
ADD third_party third_party
-RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm -o rocm_hipblaslt
+RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm -o rocm_hipblaslt -o megatron_deepspeed -o megatron_lm
ADD . .
RUN python3 -m pip install --no-cache-dir .[amdworker] && \

View file

@@ -37,8 +37,29 @@ For inference, supported percentiles include
| Name | Unit | Description |
|-----------------------------------------------------------------------------------------|------------------------|------------------------------------------------------------------------------|
| model-benchmarks/pytorch-${model_name}/${precision}_train_step_time | time (ms) | The average training step time with fp32/fp16 precision. |
-| model-benchmarks/pytorch-${model_name}/${precision}_train_throughput                     | throughput (samples/s) | The average training throughput with fp32/fp16 precision.                     |
+| model-benchmarks/pytorch-${model_name}/${precision}_train_throughput                     | throughput (samples/s) | The average training throughput with fp32/fp16 precision per GPU.             |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_step_time | time (ms) | The average inference step time with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_throughput | throughput (samples/s) | The average inference throughput with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_step_time\_${percentile} | time (ms) | The n<sup>th</sup> percentile inference step time with fp32/fp16 precision. |
| model-benchmarks/pytorch-${model_name}/${precision}_inference_throughput\_${percentile} | throughput (samples/s) | The n<sup>th</sup> percentile inference throughput with fp32/fp16 precision. |
## Megatron Model benchmarks
### `megatron-gpt`
#### Introduction
Run GPT pretrain tasks in float32, float16, or bfloat16 precision with [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) or [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed).
`Tip: batch_size in this benchmark is the global batch size; the batch size on each GPU is micro_batch_size.` A short sketch after the metrics table makes this relation concrete.
#### Metrics
| Name | Unit | Description |
|---------------------------------------------------|------------------------|---------------------------------------------------------|
| megatron-gpt/${precision}_train_step_time | time (ms) | The average training step time per iteration. |
| megatron-gpt/${precision}_train_throughput | throughput (samples/s) | The average training throughput per iteration. |
| megatron-gpt/${precision}_train_tflops             | tflops/s               | The average training TFLOPS per iteration.               |
| megatron-gpt/${precision}_train_mem_allocated | GB | The average GPU memory allocated per iteration. |
| megatron-gpt/${precision}_train_max_mem_allocated | GB | The average maximum GPU memory allocated per iteration. |
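
To make the tip above concrete, here is a minimal sketch (with illustrative values, not the defaults of any particular run) of how the global batch size, the per-GPU micro batch size, and the data-parallel size relate; the benchmark validates its arguments with the same arithmetic:

```python
# Illustrative values only.
num_gpus, num_nodes = 8, 1
tensor_model_parallel_size = 1
pipeline_model_parallel_size = 1

# GPUs left over after model parallelism form the data-parallel group.
data_parallel_size = (num_gpus * num_nodes) \
    // (tensor_model_parallel_size * pipeline_model_parallel_size)

batch_size = 2048      # global batch size (the benchmark's batch_size)
micro_batch_size = 2   # per-GPU batch size (micro_batch_size)

# micro_batch_size must be positive and fit into the per-rank share.
assert 1 <= micro_batch_size <= batch_size // data_parallel_size

# Each data-parallel rank accumulates gradients until the global batch is filled.
grad_accum_steps = batch_size // (micro_batch_size * data_parallel_size)
print(grad_accum_steps)  # 128
```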

View file

@@ -177,6 +177,7 @@ setup(
'xlrd>=2.0.1',
'xlsxwriter>=1.3.8',
'xmltodict>=0.12.0',
'types-requests',
],
extras_require=(
lambda x: {

View file

@@ -8,5 +8,6 @@ from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
+from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM']
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT']

View file

@@ -0,0 +1,508 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the megatron deepspeed GPT pretrain class."""
import json
import os
import statistics
import numpy as np
import requests
import torch
from pathlib import Path
import re
from superbench.benchmarks import BenchmarkRegistry
from superbench.benchmarks.context import Platform, Precision
from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
from superbench.benchmarks.return_code import ReturnCode
from superbench.common.utils import logger, run_command
def download_file(url, path):
"""Download file from url to path."""
response = requests.get(url)
with open(path, 'wb') as file:
file.write(response.content)
class MegatronGPT(ModelBenchmark):
"""The Megatron DeepSpeed GPT pretrain benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.
Args:
name (str): benchmark name.
parameters (str): parameters of the benchmark.
"""
super().__init__(name, parameters)
self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16, Precision.BFLOAT16]
def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()
self._parser.add_argument('--code_base', type=str, required=False, default='', help='Code base.')
self._parser.add_argument('--dataset_url', type=str, required=False, default=None, help='Dataset URL.')
self._parser.add_argument(
'--vocab_url',
type=str,
required=False,
default='https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json',
help='Vocab URL.'
)
self._parser.add_argument(
'--merges_url',
type=str,
required=False,
default='https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt',
help='Merges URL.'
)
self._parser.add_argument(
'--tokenizer_type', type=str, required=False, default='GPT2BPETokenizer', help='Tokenizer type.'
)
self._parser.add_argument('--model_size', type=float, required=False, default=6.7, help='Model size.')
self._parser.add_argument('--num_layers', type=int, required=False, default=32, help='Number of layers.')
self._parser.add_argument('--hidden_size', type=int, required=False, default=4096, help='Hidden size.')
self._parser.add_argument(
'--num_attn_heads', type=int, required=False, default=32, help='Number of attention heads.'
)
self._parser.add_argument('--micro_batch_size', type=int, required=False, default=2, help='micro batch size.')
self._parser.add_argument('--lr', type=float, required=False, default=1.2e-4, help='Learning rate.')
self._parser.add_argument('--min_lr', type=float, required=False, default=1.0e-6, help='Minimum learning rate.')
self._parser.add_argument('--init_std', type=float, required=False, default=0.009, help='Init std.')
self._parser.add_argument('--seq_len', type=int, required=False, default=2048, help='Sequence length.')
self._parser.add_argument(
'--tensor_model_parallel_size', type=int, required=False, default=1, help='Tensor model parallel size.'
)
self._parser.add_argument(
'--pipeline_model_parallel_size', type=int, required=False, default=1, help='Pipeline model parallel size.'
)
self._parser.add_argument(
'--num_gpus', type=int, required=False, default=8, help='Number of GPUs per node to run the benchmark.'
)
self._parser.add_argument(
'--num_nodes', type=int, required=False, default=1, help='Number of nodes to run the benchmark.'
)
self._parser.add_argument('--sequence_parallel', action='store_true', help='Enable Sequence parallel.')
self._parser.add_argument(
'--no_async_tensor_model_parallel_allreduce',
action='store_true',
help='No async tensor model parallel allreduce.'
)
self._parser.add_argument(
'--use_rotary_position_embeddings', action='store_true', help='Use rotary position embeddings.'
)
self._parser.add_argument(
'--no_gradient_accumulation_fusion', action='store_true', help='No gradient accumulation fusion.'
)
self._parser.add_argument('--use_flash_attn', action='store_true', help='Use flash attention.')
self._parser.add_argument('--no_masked_softmax_fusion', action='store_true', help='No masked softmax fusion.')
self._parser.add_argument('--no_bias_gelu_fusion', action='store_true', help='No bias gelu fusion.')
self._parser.add_argument('--no_bias_dropout_fusion', action='store_true', help='No bias dropout fusion.')
self._parser.add_argument(
'--train_tokens', type=int, required=False, default=300000000000, help='Train tokens.'
)
# lr configs
# Parallelism configs
self._parser.add_argument('--zero_stage', type=int, default=1, help='Zero stage.')
# Misc configs
self._parser.add_argument('--log_interval', type=int, required=False, default=1, help='Log interval.')
self._parser.add_argument('--eval_iters', type=int, default=0, help='Eval iters.')
self._parser.add_argument('--eval_interval', type=int, default=10, help='Eval interval.')
self._parser.add_argument('--num_save', type=int, default=10000, help='Num save.')
self._parser.add_argument('--save_interval', type=int, default=10000, help='Save interval.')
# Output and data configs
self._parser.add_argument('--seed', type=int, default=1234, help='Seed.')
self._parser.add_argument('--data_home', type=str, default='/tmp', help='Data home.')
self._parser.add_argument('--vocab_path', type=str, default='/tmp/gpt2-vocab.json', help='Vocab path.')
self._parser.add_argument('--merge_path', type=str, default='/tmp/gpt2-merges.txt', help='Merge path.')
self._parser.add_argument('--prescale_grad', action='store_true', help='Prescale grad.')
self._parser.add_argument(
'--hostfile', type=str, default=None, help='Hostfile to run the multi-node benchmark.'
)
self._parser.add_argument('--data_impl', type=str, default='mmap', help='Data impl.')
self._parser.add_argument('--data_prefix', type=str, default='dataset_text_document', help='Data prefix.')
self._parser.add_argument('--deepspeed', action='store_true', help='Use deepspeed.')
self._parser.add_argument('--extra', type=str, default=None, help='Extra options for Megatron.')
def _preprocess(self):
if not super()._preprocess():
return False
if not os.path.exists(self._args.code_base) or \
not os.path.exists(os.path.join(self._args.code_base, 'pretrain_gpt.py')):
logger.error('Code base is not valid.')
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
return False
data_parallel_size = self._args.num_gpus * self._num_nodes \
// self._args.pipeline_model_parallel_size // self._args.tensor_model_parallel_size
if self._args.micro_batch_size < 1 or \
self._args.micro_batch_size > (self._args.batch_size // data_parallel_size):
logger.error('Invalid micro_batch_size: it must be positive and micro_batch_size * data parallel size must not exceed the global batch size.')
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
return False
for precision in self._args.precision:
if precision not in self._supported_precision:
logger.error('Precision %s is not supported.' % precision)
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
return False
if not os.path.exists(self._args.data_home):
os.makedirs(self._args.data_home)
return True
def _parse_log(self, output):
"""Parse log output and get the performance."""
tflops_pattern = re.compile(r'TFLOPs: (\d+\.\d+)')
elapsed_time_pattern = re.compile(r'elapsed time per iteration \(ms\): (\d+\.\d+)')
mem_allocated_pattern = re.compile(r'MemAllocated=([\d.]+)[KMGTPEZY]?B')
max_mem_allocated_pattern = re.compile(r'MaxMemAllocated=([\d.]+)[KMGTPEZY]?B')
lines = output.splitlines()
tflops = []
mem_allocated = []
max_mem_allocated = []
iteration_times = []
for line in lines:
if 'TFLOPs' in line:
tflops_matches = tflops_pattern.search(line)
elapsed_time_match = elapsed_time_pattern.search(line)
if tflops_matches:
tflops_values = float(tflops_matches.group(1))
tflops.append(tflops_values)
if elapsed_time_match:
elapsed_time_value = float(elapsed_time_match.group(1))
iteration_times.append(elapsed_time_value)
if 'MaxMemAllocated' in line:
mem_allocated_match = mem_allocated_pattern.search(line)
max_mem_allocated_match = max_mem_allocated_pattern.search(line)
if mem_allocated_match:
mem_allocated_value = float(mem_allocated_match.group(1))
mem_allocated.append(mem_allocated_value)
if max_mem_allocated_match:
max_mem_allocated_value = float(max_mem_allocated_match.group(1))
max_mem_allocated.append(max_mem_allocated_value)
return iteration_times, tflops, mem_allocated, max_mem_allocated
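For reference, a self-contained sketch of what these regexes extract; the log lines below are synthetic and only shaped like Megatron/DeepSpeed iteration output (an assumption for illustration, not the exact format):

```python
import re

log = (
    'iteration 5/10 | elapsed time per iteration (ms): 75239.2 | TFLOPs: 149.14\n'
    'RANK=0 MemAllocated=17.5GB MaxMemAllocated=66.9GB\n'
)
print(re.search(r'elapsed time per iteration \(ms\): (\d+\.\d+)', log).group(1))  # 75239.2
print(re.search(r'TFLOPs: (\d+\.\d+)', log).group(1))                             # 149.14
print(re.search(r'MemAllocated=([\d.]+)[KMGTPEZY]?B', log).group(1))              # 17.5
print(re.search(r'MaxMemAllocated=([\d.]+)[KMGTPEZY]?B', log).group(1))           # 66.9
```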
def __prepare_deepspeed_config(self, precision_megatron):
"""Prepare deepspeed configs."""
self._config_json_path = os.path.join(self._args.data_home, 'ds_config_gpt.json')
# Build the deepspeed config from a template dict
precision_template = {
'enabled': True,
'loss_scale': 0,
'loss_scale_window': 500,
'hysteresis': 2,
'min_loss_scale': 1,
'initial_scale_power': 11
}
ds_config_template = {
'train_batch_size': self._args.batch_size,
'train_micro_batch_size_per_gpu': self._args.micro_batch_size,
'steps_per_print': self._args.log_interval,
'zero_optimization': {
'stage': self._args.zero_stage
},
'gradient_clipping': 1.0,
'prescale_gradients': self._args.prescale_grad,
}
if len(precision_megatron) > 0:
ds_config_template[precision_megatron] = precision_template
# Write to config json file
with open(self._config_json_path, 'w') as file:
json.dump(ds_config_template, file, indent=4)
deepspeed_options = f'\
--deepspeed \
--deepspeed_config {self._config_json_path} \
--zero-stage {self._args.zero_stage} \
--pipeline-model-parallel-size {self._args.pipeline_model_parallel_size}'
if self._args.pipeline_model_parallel_size <= 1:
deepspeed_options = f'{deepspeed_options} --no-pipeline-parallel'
return deepspeed_options
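As a concrete illustration, with the argument defaults plus `--batch_size 2048` and fp16 precision (values assumed for this example, not captured from a real run), the method writes roughly this `ds_config_gpt.json`:

```python
import json

ds_config = {
    'train_batch_size': 2048,               # --batch_size
    'train_micro_batch_size_per_gpu': 2,    # --micro_batch_size default
    'steps_per_print': 1,                   # log interval default
    'zero_optimization': {'stage': 1},      # --zero_stage default
    'gradient_clipping': 1.0,
    'prescale_gradients': False,            # --prescale_grad not set
    'fp16': {                               # precision_megatron == 'fp16'
        'enabled': True,
        'loss_scale': 0,
        'loss_scale_window': 500,
        'hysteresis': 2,
        'min_loss_scale': 1,
        'initial_scale_power': 11,
    },
}
print(json.dumps(ds_config, indent=4))
```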
def _megatron_command(self, precision): # noqa: C901
"""Generate megatron command."""
if precision == Precision.FLOAT32:
precision_megatron = ''
elif precision == Precision.FLOAT16:
precision_megatron = '--fp16'
elif precision == Precision.BFLOAT16:
precision_megatron = '--bf16'
megatron_options = f'\
--override-opt_param-scheduler \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--tensor-model-parallel-size {self._args.tensor_model_parallel_size} \
--init-method-std {self._args.init_std} \
--lr-decay-samples 43945312 \
--lr-warmup-samples {self._args.num_warmup * self._args.batch_size} \
--lr-decay-style cosine \
--micro-batch-size {self._args.micro_batch_size} \
--global-batch-size {self._args.batch_size} \
--num-layers {self._args.num_layers} \
--hidden-size {self._args.hidden_size} \
--num-attention-heads {self._args.num_attn_heads} \
--seq-length {self._args.seq_len} \
--max-position-embeddings {self._args.seq_len} \
--train-tokens {self._args.train_tokens} \
--train-samples {self._args.num_steps * self._args.batch_size} \
--lr {self._args.lr} \
--min-lr {self._args.min_lr} \
--split 949,50,1 \
--log-interval {self._args.log_interval} \
--eval-interval {self._args.eval_interval} \
--eval-iters {self._args.eval_iters} \
--save-interval {self._args.save_interval} \
--weight-decay 0.1 \
--clip-grad 1.0 \
--hysteresis 2 \
--num-workers {self._args.num_workers} \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--optimizer adam \
--use-distributed-optimizer \
{precision_megatron} \
--seed {self._args.seed}'
if self._args.sequence_parallel:
megatron_options = f'{megatron_options} --sequence-parallel'
if self._args.no_async_tensor_model_parallel_allreduce:
megatron_options = f'{megatron_options} --no-async-tensor-model-parallel-allreduce'
if self._args.use_rotary_position_embeddings:
megatron_options = f'{megatron_options} --use-rotary-position-embeddings'
if self._args.no_gradient_accumulation_fusion:
megatron_options = f'{megatron_options} --no-gradient-accumulation-fusion'
if self._args.use_flash_attn:
megatron_options = f'{megatron_options} --use-flash-attn'
if self._args.no_masked_softmax_fusion:
megatron_options = f'{megatron_options} --no-masked-softmax-fusion'
if self._args.no_bias_gelu_fusion:
megatron_options = f'{megatron_options} --no-bias-gelu-fusion'
if self._args.no_bias_dropout_fusion:
megatron_options = f'{megatron_options} --no-bias-dropout-fusion'
if self._args.extra:
megatron_options = f'{megatron_options} {self._args.extra}'
command = ''
script_path = os.path.join(self._args.code_base, 'pretrain_gpt.py')
if self._args.deepspeed:
deepspeed_option = self.__prepare_deepspeed_config(precision_megatron.lstrip('--'))
if self._num_nodes > 1:
command = f'torchrun {self._distributed_args} ' + \
f'{script_path} {megatron_options} {self._data_options} {deepspeed_option}'
else:
command = f'deepspeed {script_path} {megatron_options} {self._data_options} {deepspeed_option}'
else:
command = f'torchrun {self._distributed_args} {script_path} {megatron_options} {self._data_options}'
return command
def _train_step(self, precision): # noqa: E501
"""Train the model and get the performance."""
command = self._megatron_command(precision)
# Hide the MPI local rank from the launched process; restore it afterwards only if it was set.
local_rank = os.environ.pop('OMPI_COMM_WORLD_LOCAL_RANK', None)
logger.info('Running command: {}.'.format(command))
output = run_command(command, flush_output=True)
if local_rank is not None:
    os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = local_rank
iteration_times = []
info = {}
# last rank will print the result, first rank will print the memory usage
if self._num_nodes == 1 or \
int(os.environ['OMPI_COMM_WORLD_RANK']) == int(os.environ['OMPI_COMM_WORLD_SIZE']) - 1 \
or int(os.environ['OMPI_COMM_WORLD_RANK']) == 0:
iteration_times, tflops, mem_allocated, max_mem_allocated = self._parse_log(output.stdout)
if len(tflops) > 0:
info['tflops'] = tflops
if len(mem_allocated) > 0:
info['mem_allocated'] = mem_allocated
if len(max_mem_allocated) > 0:
info['max_mem_allocated'] = max_mem_allocated
if not iteration_times:
iteration_times = [-1 for i in range(self._args.num_steps)]
return iteration_times, info
def _sync_result(self, data):
"""Sync the result of model benchmarking.
Args:
data (list): the data to be reduced.
"""
from mpi4py import MPI
comm = MPI.COMM_WORLD
data = np.array(data, dtype=np.float64)
# Allreduce with MAX so every rank ends up with the real values;
# ranks that could not parse the log contribute -1 placeholders.
result = np.zeros_like(data)
comm.Allreduce([data, MPI.DOUBLE], [result, MPI.DOUBLE], op=MPI.MAX)
return result.tolist()
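A minimal sketch of the reduction semantics, runnable under `mpirun -np 2 python sketch.py` (values illustrative): only some ranks parse real step times, the rest hold the `-1` placeholders set in `_train_step`, and `MPI.MAX` propagates the real values to every rank:

```python
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
# Pretend the last rank parsed a real step time; others hold the placeholder.
data = np.array([75239.2 if comm.rank == comm.size - 1 else -1.0])
result = np.zeros_like(data)
comm.Allreduce([data, MPI.DOUBLE], [result, MPI.DOUBLE], op=MPI.MAX)
print(comm.rank, result)  # every rank prints [75239.2]
```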
def _process_info(self, model_action, precision, info):
"""Process the result of model benchmarking."""
precision_metric = {'float16': 'fp16', 'float32': 'fp32', 'bfloat16': 'bf16'}
if precision.value in precision_metric.keys():
precision = precision_metric[precision.value]
for key, values in info.items():
metric = '{}_{}_{}'.format(precision, model_action, key)
self._result.add_raw_data(metric, values, self._args.log_raw_data)
self._result.add_result(metric, statistics.mean(values))
logger.info(
'Average {} - round: {}, model: {}, precision: {}, value: {:.6f}.'.format(
key, self._curr_run_index, self._name, precision, statistics.mean(values)
)
)
def _judge_gpu_availability(self):
"""Judge GPUs' availability according to arguments and running environment."""
self._gpu_available = not self._args.no_gpu and torch.cuda.is_available()
def _init_distributed_setting(self):
"""Initialize the distributed library and bind the worker to GPU.
Return:
True if distributed library is initialized successfully.
"""
if not os.getenv('OMPI_COMM_WORLD_SIZE'):
logger.error('MPI is not enabled.')
return False
self._num_nodes = int(os.getenv('OMPI_COMM_WORLD_SIZE')) // int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
if self._num_nodes > 1:
if not self._args.hostfile:
sb_hostfile = os.path.join(os.environ.get('SB_WORKSPACE', '.'), 'hostfile')
if os.path.exists(sb_hostfile):
hosts = open(sb_hostfile).read().split('\n')
hosts = [f'{host} slots={self._args.num_gpus}' for host in hosts if host != '']
self._args.hostfile = os.path.join(self._args.data_home, 'hostfile')
with open(self._args.hostfile, 'w') as file:
file.write('\n'.join(hosts))
if not os.path.exists(self._args.hostfile):
logger.error('Hostfile not found.')
return False
hosts = open(self._args.hostfile, 'r').readlines()
if self._num_nodes != len(hosts):
logger.error('MPI init failed since the hostfile does not match the MPI settings.')
return False
addr = os.getenv('MASTER_ADDR', hosts[0].split()[0])
port = os.getenv('MASTER_PORT', '29500')
node_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) // int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
self._distributed_args = f'--nproc_per_node {self._args.num_gpus} --nnodes {self._num_nodes} ' + \
f'--node_rank {node_rank} --master_addr {addr} --master_port {port}'
return True
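For a hypothetical 2-node x 8-GPU run (host names are placeholders), the generated hostfile and torchrun arguments would look roughly like this:

```python
hosts = ['node-0', 'node-1']    # placeholder host names
num_gpus, node_rank, addr, port = 8, 0, 'node-0', '29500'

hostfile = '\n'.join(f'{h} slots={num_gpus}' for h in hosts)
# node-0 slots=8
# node-1 slots=8

distributed_args = (
    f'--nproc_per_node {num_gpus} --nnodes {len(hosts)} '
    f'--node_rank {node_rank} --master_addr {addr} --master_port {port}'
)
print(hostfile)
print(distributed_args)
```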
def _generate_dataset(self):
"""Generate dataset for benchmarking.
Return:
True if dataset is created successfully.
"""
self._vocab_path = str(Path(self._args.data_home) / 'gpt2-vocab.json')
download_file(self._args.vocab_url, self._vocab_path)
self._merges_path = str(Path(self._args.data_home) / 'gpt2-merges.txt')
download_file(self._args.merges_url, self._merges_path)
if not os.path.exists(os.path.join(self._args.data_home, f'{self._args.data_prefix}.bin')) \
or not os.path.exists(os.path.join(self._args.data_home, f'{self._args.data_prefix}.idx')):
if self._args.dataset_url:
self._raw_data_path = str(Path(self._args.data_home) / 'data.json')
download_file(self._args.dataset_url, self._raw_data_path)
command = (
'python3 '
f'{os.path.join(self._args.code_base, "tools/preprocess_data.py")} '
f'--input {self._raw_data_path} '
f'--tokenizer-type {self._args.tokenizer_type} '
f'--output-prefix {os.path.join(self._args.data_home, "dataset")} '
f'--workers {str(self._args.num_workers)} '
f'--vocab-file {self._vocab_path} '
f'--merge-file {self._merges_path}'
)
# split documents and binarize the dataset in one preprocessing pass
run_command(command, flush_output=True)
if not os.path.exists(os.path.join(self._args.data_home, f'{self._args.data_prefix}.bin')) \
or not os.path.exists(os.path.join(self._args.data_home, f'{self._args.data_prefix}.idx')):
logger.error('Failed to generate the dataset.')
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
return False
else:
logger.error('No dataset or dataset url provided.')
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
return False
self._data_path = os.path.join(self._args.data_home, f'{self._args.data_prefix}')
self._data_options = f'\
--vocab-file {self._vocab_path} \
--merge-file {self._merges_path} \
--data-path {self._data_path} \
--data-impl {self._args.data_impl}'
logger.info('Dataset prepared successfully.')
return True
def _set_force_fp32(self):
"""Set force FP32."""
pass
def _init_dataloader(self):
"""Initialize the dataloader.
Return:
True if dataloader is created successfully.
"""
return True
def _create_optimizer(self):
"""Create the optimzier instance used for training and wrap with distributed library if need.
Return:
True if optimizer instance is created successfully.
"""
return True
def _create_model(self, precision):
"""Construct the model for benchmarking.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
"""
return True
def _inference_step(self, precision):
"""Define the inference process.
Args:
precision (Precision): precision of model and input data,
such as float32, float16.
Return:
The latency list of every inference operation.
"""
pass
def _cal_params_count(self):
"""Calculate the parameters scale of the model.
Return:
The count of trainable parameters.
"""
pass
# Register megatron-gpt benchmark.
BenchmarkRegistry.register_benchmark('megatron-gpt', MegatronGPT, parameters='', platform=Platform.CUDA)
BenchmarkRegistry.register_benchmark('megatron-gpt', MegatronGPT, parameters='', platform=Platform.ROCM)

View file

@@ -7,6 +7,7 @@ import math
import time
import statistics
from abc import abstractmethod
from typing import Union
from superbench.common.utils import logger, stdout_logger
from superbench.benchmarks import Precision, ModelAction, DistributedImpl, DistributedBackend, BenchmarkType, ReturnCode
@@ -263,6 +264,10 @@ class ModelBenchmark(Benchmark):
# The unit of step time should be millisecond.
step_times = self._train_step(precision)
if isinstance(step_times, tuple):
    info = step_times[1]
    step_times = step_times[0]
    self._process_info(ModelAction.TRAIN, precision, info)
step_times = self.__process_model_result(ModelAction.TRAIN, precision, step_times)
if not step_times:
self._result.set_return_code(ReturnCode.INVALID_BENCHMARK_RESULT)
@@ -302,7 +307,7 @@ class ModelBenchmark(Benchmark):
return True
@abstractmethod
-def _train_step(self, precision):
+def _train_step(self, precision) -> Union[list, tuple]:
"""Define the training process.
Args:
@@ -418,6 +423,7 @@ class ModelBenchmark(Benchmark):
precision_metric = {'float16': 'fp16', 'float32': 'fp32', 'float64': 'fp64', 'bfloat16': 'bf16'}
if precision.value in precision_metric.keys():
precision = precision_metric[precision.value]
metric_s = '{}_{}_step_time'.format(precision, model_action)
metric_t = '{}_{}_throughput'.format(precision, model_action)
# The unit of step time is millisecond, use it to calculate the throughput with the unit samples/sec.
@@ -428,7 +434,7 @@ class ModelBenchmark(Benchmark):
if model_action == ModelAction.TRAIN:
step_times = self._sync_result(step_times)
-if not step_times:
+if not step_times or statistics.mean(step_times) < 0:
return None
if self._local_rank is None or self._global_rank == 0:
self._result.add_result(metric_s, statistics.mean(step_times))
@@ -468,3 +474,13 @@ class ModelBenchmark(Benchmark):
step_time = statistics.mean(duration) if len(duration) < self._args.log_n_steps \
else statistics.mean(duration[-self._args.log_n_steps:])
stdout_logger.log(f'{self._name} - {precision.value}: step {curr_step}, step time {step_time}\n')
def _process_info(self, model_action, precision, info):
"""Process other info.
Args:
model_action (ModelAction): train or inference.
precision (Precision): precision of model.
info (dict): other info.
"""
pass

View file

@@ -207,6 +207,23 @@ superbench:
seq_length: 224
batch_size: 1
precision: int8
megatron-gpt:
modes:
- name: mpi
proc_num: 1
node_num: all
parameters:
code_base: /opt/superbench/third_party/Megatron/Megatron-DeepSpeed/
dataset_url: https://huggingface.co/datasets/suolyer/pile_bookcorpus2/raw/main/test.json
batch_size: 2048
num_warmup: 0
num_steps: 10
precision:
- float16
- bfloat16
deepspeed: yes
sequence_parallel: yes
use_rotary_position_embeddings: yes
gpt_models:
<<: *default_pytorch_mode
models:

View file

@@ -0,0 +1,357 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for BERT model benchmarks."""
import os
from pathlib import Path
import statistics
from unittest import mock
import unittest
from superbench.benchmarks.context import ModelAction, Precision
from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from tests.helper.testcase import BenchmarkTestCase
class MegatronGPTTest(BenchmarkTestCase, unittest.TestCase):
"""Tests for IBBenchmark benchmark."""
@classmethod
def setUpClass(cls):
"""Hook method for setting up class fixture before running tests in the class."""
super().setUpClass()
cls.benchmark_name = 'megatron-gpt'
cls.createMockEnvs(cls)
cls.hostfile_path = os.path.join(cls._tmp_dir, 'hostfile')
@classmethod
def tearDownClass(cls):
"""Hook method for deconstructing the class fixture after running all tests in the class."""
for p in [
Path(cls._tmp_dir) / 'pretrain_gpt.py',
Path(cls._tmp_dir) / 'customdataset_text_document.bin',
Path(cls._tmp_dir) / 'customdataset_text_document.idx',
Path(cls._tmp_dir) / 'hostfile'
]:
if p.is_file():
p.unlink()
super().tearDownClass()
@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_gpt_preprocess(self, mock_generate_dataset):
"""Test megatron-gpt benchmark."""
# Check registry.
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
assert (benchmark_cls)
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--hostfile {self.hostfile_path} --batch_size 2048',
)
# Check init distributed setting.
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
with open(self.hostfile_path, 'w') as f:
f.write('host1\n')
f.write('host2\n')
f.write('host3\n')
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is False)
assert (benchmark.return_code == ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
benchmark = benchmark_cls(
self.benchmark_name,
parameters='--hostfile xxx --batch_size 2048',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is False)
assert (benchmark.return_code == ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
os.environ['OMPI_COMM_WORLD_SIZE'] = '3'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--hostfile {self.hostfile_path} --batch_size 2048',
)
mock_generate_dataset.return_value = True
benchmark._preprocess()
self.assertEqual(benchmark._num_nodes, 3)
self.assertEqual(
benchmark._distributed_args,
'--nproc_per_node {0} --nnodes {1} --node_rank {2} --master_addr {3} --master_port {4}'.format(
benchmark._args.num_gpus, benchmark._num_nodes, 0, 'localhost', '12345'
)
)
# Check preprocessing.
# Negative cases
# invalid code_base: pretrain_gpt.py does not exist yet
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} --batch_size 2048',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is False)
self.createMockFiles(['pretrain_gpt.py'])
# invalid micro batch size
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} --micro_batch_size -1',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is False)
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} --micro_batch_size 4096',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is False)
# invalid precision
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} \
--batch_size 2048 --precision int8',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is False)
# Positive cases
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} --batch_size 2048',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is True)
def test_megatron_gpt_dataset(self):
"""Test dataset genreation."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
assert (benchmark_cls)
os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
# use existing dataset
self.createMockFiles(['customdataset_text_document.bin', 'customdataset_text_document.idx'])
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base /root/Megatron-DeepSpeed --data_home {self._tmp_dir} \
--batch_size 2048 --data_prefix customdataset_text_document',
)
ret = benchmark._preprocess()
ret = benchmark._generate_dataset()
assert (ret is True)
@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_gpt_command(self, mock_generate_dataset):
"""Test command generation."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
assert (benchmark_cls)
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
with open(self.hostfile_path, 'w') as f:
f.write('host1\n')
f.write('host2\n')
# check command generation (dataset generation is mocked)
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} \
--num_warmup 0 --num_steps 10 --batch_size 2048 --data_prefix dataset_text_document',
)
mock_generate_dataset.return_value = True
benchmark._preprocess()
benchmark._data_options = f'\
--vocab-file {self._tmp_dir}/gpt2-vocab.json \
--merge-file {self._tmp_dir}/gpt2-merges.txt \
--data-path {self._tmp_dir}/dataset_text_document \
--data-impl mmap'
script_path = str(Path(self._tmp_dir) / 'pretrain_gpt.py')
expected_command = 'torchrun {distributed_args} {script_path} \
--override-opt_param-scheduler \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--tensor-model-parallel-size 1 \
--init-method-std 0.009 \
--lr-decay-samples 43945312 \
--lr-warmup-samples 0 \
--lr-decay-style cosine \
--micro-batch-size 2 \
--global-batch-size 2048 \
--num-layers 32 \
--hidden-size 4096 \
--num-attention-heads 32 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-tokens 300000000000 \
--train-samples 20480 \
--lr 0.00012 \
--min-lr 1e-06 \
--split 949,50,1 \
--log-interval 1 \
--eval-interval 10 \
--eval-iters 0 \
--save-interval 10000 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--hysteresis 2 \
--num-workers 8 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--optimizer adam \
--use-distributed-optimizer \
{precision} \
--seed 1234 {data_options}'
precision = Precision.FLOAT32
command = benchmark._megatron_command(precision)
self.assertEqual(
command,
expected_command.format(
precision='',
data_options=benchmark._data_options,
distributed_args=benchmark._distributed_args,
script_path=script_path
)
)
precision = Precision.FLOAT16
command = benchmark._megatron_command(precision)
self.assertEqual(
command,
expected_command.format(
precision='--fp16',
data_options=benchmark._data_options,
distributed_args=benchmark._distributed_args,
script_path=script_path
)
)
precision = Precision.BFLOAT16
command = benchmark._megatron_command(precision)
self.assertEqual(
command,
expected_command.format(
precision='--bf16',
data_options=benchmark._data_options,
distributed_args=benchmark._distributed_args,
script_path=script_path
)
)
os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} \
--num_warmup 0 --num_steps 10 --batch_size 2048 --data_prefix dataset_text_document --deepspeed',
)
mock_generate_dataset.return_value = True
benchmark._preprocess()
benchmark._data_options = f'\
--vocab-file {self._tmp_dir}/gpt2-vocab.json \
--merge-file {self._tmp_dir}/gpt2-merges.txt \
--data-path {self._tmp_dir}/dataset_text_document \
--data-impl mmap'
command = benchmark._megatron_command(Precision.BFLOAT16)
expected_command = 'deepspeed {script_path} \
--override-opt_param-scheduler \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--tensor-model-parallel-size 1 \
--init-method-std 0.009 \
--lr-decay-samples 43945312 \
--lr-warmup-samples 0 \
--lr-decay-style cosine \
--micro-batch-size 2 \
--global-batch-size 2048 \
--num-layers 32 \
--hidden-size 4096 \
--num-attention-heads 32 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-tokens 300000000000 \
--train-samples 20480 \
--lr 0.00012 \
--min-lr 1e-06 \
--split 949,50,1 \
--log-interval 1 \
--eval-interval 10 \
--eval-iters 0 \
--save-interval 10000 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--hysteresis 2 \
--num-workers 8 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--optimizer adam \
--use-distributed-optimizer \
{precision} \
--seed 1234 {data_options} {deepspeed_options}'
expect_ds_options = f'\
--deepspeed \
--deepspeed_config {benchmark._config_json_path} \
--zero-stage 1 \
--pipeline-model-parallel-size 1 --no-pipeline-parallel'
self.assertEqual(
command,
expected_command.format(
precision='--bf16',
data_options=benchmark._data_options,
script_path=script_path,
deepspeed_options=expect_ds_options
)
)
@decorator.load_data('tests/data/megatron_deepspeed.log')
@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_parse_log(self, raw_output, mock_generate_dataset):
"""Test parse log function."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
assert (benchmark_cls)
os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
# parse a sample log (dataset generation is mocked)
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --num_warmup 0 --num_steps 10 --batch_size 2048',
)
mock_generate_dataset.return_value = True
benchmark._preprocess()
benchmark._data_options = f'\
--vocab-file {self._tmp_dir}/gpt2-vocab.json \
--merge-file {self._tmp_dir}/gpt2-merges.txt \
--data-path {self._tmp_dir}/dataset_text_document \
--data-impl mmap'
iteration_times, tflops, mem_allocated, max_mem_allocated = benchmark._parse_log(raw_output)
assert (statistics.mean(iteration_times) == 75239.24)
assert (statistics.mean(tflops) == 149.136)
assert (statistics.mean(mem_allocated) == 17.54)
assert (statistics.mean(max_mem_allocated) == 66.97)
info = {'tflops': tflops, 'mem_allocated': mem_allocated, 'max_mem_allocated': max_mem_allocated}
benchmark._process_info(ModelAction.TRAIN, Precision.FLOAT16, info)
assert (benchmark.result is not None)
assert (benchmark.result['fp16_train_tflops'][0] == 149.136)
assert (benchmark.result['fp16_train_mem_allocated'][0] == 17.54)
assert (benchmark.result['fp16_train_max_mem_allocated'][0] == 66.97)

File diff suppressed because it is too large

third_party/Makefile (vendored, 23 lines changed)
View file

@@ -11,12 +11,12 @@ HPCX_HOME ?= /opt/hpcx
CUDA_VER ?= $(shell nvcc --version | grep 'release' | awk '{print $$6}' | cut -c2- | cut -d '.' -f1-2)
ROCBLAS_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)
-.PHONY: all cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt
+.PHONY: all cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt megatron_lm megatron_deepspeed
# Build all targets.
all: cuda rocm
-cuda: common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest gpcnet cuda_gpuburn
-rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt
+cuda: common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest gpcnet cuda_gpuburn megatron_lm megatron_deepspeed
+rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt megatron_deepspeed
cpu: common cpu_perftest
common: cpu_hpl cpu_stream fio
directx_amd: directx_amf_encoding_latency
@@ -171,3 +171,20 @@ directx_amf_encoding_latency:
del vs_buildtools.exe && echo "Deleted vs_buildtools.exe" && \
"C:\temp\BuildTools\MSBuild\Current\Bin\MSBuild.exe" "AMF\amf\public\samples\CPPSamples_vs2019.sln" /t:EncoderLatency /p:Platform=x64 /p:Configuration=Release /p:OutDir="%SB_MICRO_PATH%\bin" \
)
# Install Megatron-LM
megatron_lm:
if [ ! -d "Megatron/Megatron-LM" ]; then \
git clone "https://github.com/NVIDIA/Megatron-LM.git" "Megatron/Megatron-LM"; \
fi
cd Megatron && \
python -m pip install -r requirements.txt
# Install Megatron-DeepSpeed
megatron_deepspeed:
if [ ! -d "Megatron/Megatron-DeepSpeed" ]; then \
git clone "https://github.com/microsoft/Megatron-DeepSpeed.git" "Megatron/Megatron-DeepSpeed"; \
fi
cd Megatron && \
python -m pip install -r requirements.txt && \
python -m pip install DeepSpeed

third_party/Megatron/requirements.txt (vendored, new file, 13 lines)
View file

@@ -0,0 +1,13 @@
nltk
parameterized
pybind11
regex
six
# versions from HF transformers
black==21.4b0
isort>=5.5.4
tqdm
sentencepiece
wandb
einops
typing_extensions==4.5.0