Merge branch 'xiongyf/support-te-fp8' of https://github.com/microsoft/superbenchmark into xiongyf/support-te-fp8

Commit a6da958d6f
@@ -409,6 +409,19 @@ Test the performance of large scale matmul operation with multiple GPUs:

| pytorch-sharding-matmul/allreduce_time | time (ms) | Time of sharding matmul using allreduce. |
| pytorch-sharding-matmul/allgather_time | time (ms) | Time of sharding matmul using allgather. |

### `dist-inference`

#### Introduction

Test the performance of distributed model inference.
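
This benchmark is launched via `torch.distributed`; for example, with the example script added in this commit (assuming one node with 8 GPUs):

```bash
python3 -m torch.distributed.launch --nproc_per_node=8 examples/benchmarks/dist_inference.py
```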
#### Metrics

| Name                                             | Unit      | Description                                            |
|--------------------------------------------------|-----------|--------------------------------------------------------|
| pytorch-dist-inference/step_times                | time (ms) | Average time of model inference runs.                  |
| pytorch-dist-inference/step_times_${percentile}  | time (ms) | Tail (50,90,95,99,99.9) time of model inference runs.  |

## Storage Benchmarks

### `disk-benchmark`
@@ -0,0 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Micro benchmark example for distributed inference with pytorch.

Commands to run:
    python3 -m torch.distributed.launch --nproc_per_node=8 examples/benchmarks/dist_inference.py
"""

from superbench.benchmarks import Framework, BenchmarkRegistry
from superbench.common.utils import logger

if __name__ == '__main__':
    context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)

    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if benchmark:
        logger.info(
            'benchmark: {}, return code: {}, result: {}'.format(
                benchmark.name, benchmark.return_code, benchmark.result
            )
        )
@@ -15,6 +15,7 @@ from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import Cu
from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
from superbench.benchmarks.micro_benchmarks.cudnn_function import CudnnBenchmark
from superbench.benchmarks.micro_benchmarks.disk_performance import DiskBenchmark
from superbench.benchmarks.micro_benchmarks.dist_inference import DistInference
from superbench.benchmarks.micro_benchmarks.cpu_memory_bw_latency_performance import CpuMemBwLatencyBenchmark
from superbench.benchmarks.micro_benchmarks.cpu_stream_performance import CpuStreamBenchmark
from superbench.benchmarks.micro_benchmarks.cpu_hpl_performance import CpuHplBenchmark

@@ -43,6 +44,7 @@ __all__ = [
    'CudaNcclBwBenchmark',
    'CudnnBenchmark',
    'DiskBenchmark',
    'DistInference',
    'GPCNetBenchmark',
    'GemmFlopsBenchmark',
    'GpuBurnBenchmark',
@@ -0,0 +1,455 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the distributed inference benchmark."""

import os
import time

import torch
import torch.nn.functional as F
import torch.distributed as dist

from superbench.common.utils import logger
from superbench.benchmarks import DistributedImpl, DistributedBackend, BenchmarkRegistry, ReturnCode, Precision
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.context import Enum


class ComputationKernelType(Enum):
    """The Enum class representing different computation kernel types."""
    ADDMM = 'addmm'
    MATMUL = 'matmul'
    MUL = 'mul'


class CommunicationKernelType(Enum):
    """The Enum class representing different communication kernel types."""
    ALLGATHER = 'allgather'
    ALLREDUCE = 'allreduce'
    ALLTOALL = 'alltoall'


class ActivationKernelType(Enum):
    """The Enum class representing different activation kernel types."""
    RELU = 'relu'
    SIGMOID = 'sigmoid'
    TANH = 'tanh'


class DistInferenceModel(torch.nn.Module):
    """The model class for the distributed inference benchmark."""
    def __init__(
        self, input_size, hidden_size, num_layers, computation, communication, activation, precision, num_ranks, device
    ):
        """Constructor.

        Args:
            input_size (int): input data dimension.
            hidden_size (int): hidden layer dimension.
            num_layers (int): number of layers in the model.
            computation (ComputationKernelType): type of computation kernel of this model.
            communication (CommunicationKernelType): type of communication kernel of this model.
            activation (ActivationKernelType): type of activation kernel of this model.
            precision (Precision): data type of this model.
            num_ranks (int): number of ranks this model runs on.
            device (torch.device): device this model runs on.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
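        # Parameters are randomly initialized; the benchmark measures step latency, not model quality.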
        self.weights = torch.rand(
            self.input_size, self.hidden_size, dtype=getattr(torch, precision.value), device=device
        )
        self.bias = torch.rand(self.hidden_size, dtype=getattr(torch, precision.value), device=device)
        self.num_ranks = num_ranks
        self.step_times = []

        self.__init_computation_kernels(computation)
        self.__init_communication_kernels(communication)
        self.__init_activation_kernels(activation)

    def __init_computation_kernels(self, computation):
        """Select computation kernel according to option.

        Args:
            computation (ComputationKernelType): the type of the computation kernel to run.
        """
        self.computation_kernel = None
        if computation == ComputationKernelType.ADDMM:
            self.computation_kernel = lambda x: torch.addmm(self.bias, x, self.weights)
        elif computation == ComputationKernelType.MATMUL:
            self.computation_kernel = lambda x: torch.matmul(x, self.weights)
        elif computation == ComputationKernelType.MUL:
            self.computation_kernel = lambda x: torch.mul(x, x)

    def __init_communication_kernels(self, communication):
        """Select communication kernel according to option.

        Args:
            communication (CommunicationKernelType): the type of the communication kernel to run.
        """
        self.communication_kernel = None
        if communication == CommunicationKernelType.ALLGATHER:
            self.communication_kernel = self.__all_gather_wrapper
        elif communication == CommunicationKernelType.ALLREDUCE:
            self.communication_kernel = self.__all_reduce_wrapper
        elif communication == CommunicationKernelType.ALLTOALL:
            self.communication_kernel = self.__all_to_all_wrapper

    def __init_activation_kernels(self, activation):
        """Select activation kernel according to option.

        Args:
            activation (ActivationKernelType): the type of the activation kernel to run.
        """
        self.activation_kernel = None
        if activation == ActivationKernelType.RELU:
            self.activation_kernel = F.relu
        elif activation == ActivationKernelType.SIGMOID:
            self.activation_kernel = F.sigmoid
        elif activation == ActivationKernelType.TANH:
            self.activation_kernel = F.tanh

    def __all_gather_wrapper(self, x):
        """All-gather wrapper with output initialization.

        Args:
            x (Tensor): input.

        Return:
            Tensor after all-gather.
        """
        output = torch.empty(
            [x.shape[0] * self.num_ranks] + list(x.shape[1:]), dtype=x.dtype, device=x.device
        )
        dist.all_gather_into_tensor(output, x)
        return output

    def __all_reduce_wrapper(self, x):
        """All-reduce wrapper.

        Args:
            x (Tensor): input.

        Return:
            Tensor after all-reduce.
        """
        dist.all_reduce(x)
        return x

    def __all_to_all_wrapper(self, x):
        """All-to-all wrapper with output initialization.

        Args:
            x (Tensor): input.

        Return:
            Tensor after all-to-all.
        """
        output = torch.empty_like(x)
        dist.all_to_all_single(output, x)
        return output

    def forward(self, x):
        """Do forward loops.

        Args:
            x (Tensor): input.

        Return:
            Tensor after the whole inference process.
        """
        activation_out = None
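        # Each iteration runs one layer on the input: computation kernel, collective communication kernel, activation.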
        for i in range(self.num_layers):
            computation_out = self.computation_kernel(x)
            communication_out = self.communication_kernel(computation_out)
            activation_out = self.activation_kernel(communication_out)
        return activation_out


class DistInference(MicroBenchmark):
    """The distributed inference benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self.__world_size = 1
        self.__local_rank = 0
        torch.backends.cudnn.benchmark = True
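        # Enable cuDNN autotuning, which can speed up kernels when input sizes are fixed.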
        self.__device = None
        self.__cuda_available = False

    def __timer(self):
        """Return the current time, ensuring all previous CUDA events have finished.

        If there is no GPU present, this defaults to `time.time()`; otherwise it will
        synchronize CUDA before measuring the time.

        Return:
            Current time in seconds.
        """
        if self.__cuda_available:
            torch.cuda.synchronize()
        return time.time()

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--batch_size',
            type=int,
            default=64,
            required=False,
            help='Batch size.',
        )
        self._parser.add_argument(
            '--input_size',
            type=int,
            default=1024,
            required=False,
            help='Input dimension size.',
        )
        self._parser.add_argument(
            '--hidden_size',
            type=int,
            default=1024,
            required=False,
            help='Hidden size.',
        )
        self._parser.add_argument(
            '--num_layers',
            type=int,
            default=1,
            required=False,
            help='Number of compute-communicate-activate layers.',
        )
        self._parser.add_argument(
            '--computation_kernel',
            type=ComputationKernelType,
            default=ComputationKernelType.MATMUL,
            required=False,
            help='Computation kernel type. E.g. {}.'.format(' '.join(ComputationKernelType.get_values())),
        )
        self._parser.add_argument(
            '--communication_kernel',
            type=CommunicationKernelType,
            default=CommunicationKernelType.ALLREDUCE,
            required=False,
            help='Communication kernel type. E.g. {}.'.format(' '.join(CommunicationKernelType.get_values())),
        )
        self._parser.add_argument(
            '--activation_kernel',
            type=ActivationKernelType,
            default=ActivationKernelType.RELU,
            required=False,
            help='Activation kernel type. E.g. {}.'.format(' '.join(ActivationKernelType.get_values())),
        )
        self._parser.add_argument(
            '--precision',
            type=Precision,
            default=Precision.FLOAT32,
            required=False,
            help='Model precision. E.g. {}.'.format(' '.join(Precision.get_values())),
        )
        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=50,
            required=False,
            help='Number of warmup steps.',
        )
        self._parser.add_argument(
            '--num_steps',
            type=int,
            default=10000,
            required=False,
            help='Number of test steps.',
        )
        self._parser.add_argument(
            '--distributed_impl',
            type=DistributedImpl,
            default=DistributedImpl.DDP,
            required=False,
            help='Distributed implementations. E.g. {}.'.format(' '.join(DistributedImpl.get_values())),
        )
        self._parser.add_argument(
            '--distributed_backend',
            type=DistributedBackend,
            default=DistributedBackend.NCCL,
            required=False,
            help='Distributed backends. E.g. {}.'.format(' '.join(DistributedBackend.get_values())),
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeeds.
        """
        if not super()._preprocess():
            return False

        if self._args.distributed_impl != DistributedImpl.DDP:
            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
            logger.error(
                'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
                    self._name, self._args.distributed_impl
                )
            )
            return False

        try:
            torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
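            # WORLD_SIZE and LOCAL_RANK are populated by the torch.distributed launcher.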
            self.__world_size = int(os.environ['WORLD_SIZE'])
            self.__local_rank = int(os.environ['LOCAL_RANK'])
        except BaseException as e:
            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
            torch.distributed.destroy_process_group()
            logger.error('Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
            return False

        if torch.cuda.is_available():
            torch.cuda.set_device(self.__local_rank)
            self.__device = torch.device('cuda:{}'.format(self.__local_rank))
            self.__cuda_available = True
        else:
            self.__device = torch.device('cpu:{}'.format(self.__local_rank))
            self.__cuda_available = False

        return True

    def _prepare_model(
        self, input_size, hidden_size, num_layers, computation, communication, activation, precision, num_ranks
    ):
        """Prepare model.

        Args:
            input_size (int): input data dimension.
            hidden_size (int): hidden layer dimension.
            num_layers (int): number of layers in the model.
            computation (ComputationKernelType): type of computation kernel of this model.
            communication (CommunicationKernelType): type of communication kernel of this model.
            activation (ActivationKernelType): type of activation kernel of this model.
            precision (Precision): data type of this model.
            num_ranks (int): number of ranks this model runs on.

        Return:
            The prepared model.
        """
        model = DistInferenceModel(
            input_size, hidden_size, num_layers, computation, communication, activation, precision, num_ranks,
            self.__device
        )
        model = model.to(dtype=getattr(torch, precision.value))
        if self.__cuda_available:
            model = model.cuda()
        return model

    def _run_model(self, model, batch_size, input_size, precision, device, num_warmup, num_steps):
        """Run model and collect step times.

        Args:
            model (torch.nn.Module): model to run.
            batch_size (int): batch size of input data.
            input_size (int): input data dimension.
            precision (Precision): data type of this model.
            device (torch.device): device this model runs on.
            num_warmup (int): number of warm-up runs.
            num_steps (int): number of test runs.

        Return:
            Model step times collected.
        """
        data = torch.rand(batch_size, input_size, dtype=getattr(torch, precision.value), device=device)

        # warm up
        for i in range(num_warmup):
            model(data)

        # run and collect results
        step_times = [0.] * num_steps
        for i in range(num_steps):
            start = self.__timer()
            model(data)
            end = self.__timer()
            step_times[i] = (end - start) * 1000

        return step_times

    def _process_data(self, step_times):
        """Process data.

        Args:
            step_times (List[float]): Model step times collected.

        Return:
            True if _process_data succeeds.
        """
        if not self._process_numeric_result('step_times', step_times, cal_percentile=True):
            return False
        return True

    def _benchmark(self):
        """Implementation for benchmarking.

        Return:
            True if _benchmark succeeds.
        """
        batch_size = self._args.batch_size
        input_size = self._args.input_size
        hidden_size = self._args.hidden_size
        num_layers = self._args.num_layers
        computation = self._args.computation_kernel
        communication = self._args.communication_kernel
        activation = self._args.activation_kernel
        precision = self._args.precision
        num_warmup = self._args.num_warmup
        num_steps = self._args.num_steps

        if self.__local_rank == 0:
            logger.info(
                'Distributed Inference - using {} GPUs: '
                'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
                'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
                'num_warmup={} num_steps={}'.format(
                    self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
                    activation, precision, num_warmup, num_steps
                )
            )

        # Prepare model
        model = self._prepare_model(
            input_size, hidden_size, num_layers, computation, communication, activation, precision, self.__world_size
        )

        # Run model
        step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)

        # Process data and return
        return self._process_data(step_times)

    def _postprocess(self):
        """Postprocess/cleanup operations after the benchmarking.

        Return:
            True if _postprocess succeeds.
        """
        if not super()._postprocess():
            return False

        try:
            torch.distributed.destroy_process_group()
        except BaseException as e:
            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
            logger.error('Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
            return False

        return True


BenchmarkRegistry.register_benchmark('pytorch-dist-inference', DistInference, parameters='')
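
For reference, a minimal sketch of driving the registered benchmark programmatically (parameter names mirror `add_parser_arguments()` above; the flag values shown are illustrative assumptions):

```python
from superbench.benchmarks import BenchmarkRegistry, Framework

# A sketch, assuming the process was started by the torch.distributed launcher
# so that WORLD_SIZE and LOCAL_RANK are set.
context = BenchmarkRegistry.create_benchmark_context(
    'dist-inference',
    parameters='--num_layers 4 --communication_kernel allgather --num_steps 1000',
    framework=Framework.PYTORCH,
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
    print(benchmark.name, benchmark.return_code, benchmark.result)
```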

@@ -26,6 +26,13 @@ superbench:
          node_num: 1
      frameworks:
        - pytorch
    dist_inference_pytorch_mode: &dist_inference_pytorch_mode
      modes:
        - name: torch.distributed
          proc_num: 8
          node_num: 1
          env:
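            # Assumption: disabled so the NCCL watchdog does not abort long-running collectives during benchmarking.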
            NCCL_ASYNC_ERROR_HANDLING: '0'
    common_model_config: &common_model_config
      duration: 0
      num_warmup: 64
@@ -159,6 +166,8 @@ superbench:
      <<: *default_pytorch_mode
    computation-communication-overlap:
      <<: *default_pytorch_mode
    dist-inference:
      <<: *dist_inference_pytorch_mode
    ib-traffic:
      enable: false
      modes:
@@ -22,6 +22,13 @@ superbench:
          node_num: 1
      frameworks:
        - pytorch
    dist_inference_pytorch_mode: &dist_inference_pytorch_mode
      modes:
        - name: torch.distributed
          proc_num: 8
          node_num: 1
          env:
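            # Assumption: disabled so the NCCL watchdog does not abort long-running collectives during benchmarking.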
            NCCL_ASYNC_ERROR_HANDLING: '0'
    common_model_config: &common_model_config
      duration: 0
      num_warmup: 64
@@ -165,6 +172,8 @@ superbench:
      <<: *default_pytorch_mode
    computation-communication-overlap:
      <<: *default_pytorch_mode
    dist-inference:
      <<: *dist_inference_pytorch_mode
    ib-traffic:
      enable: false
      modes:
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for distributed inference benchmark."""

import unittest

from tests.helper import decorator
import tests.benchmarks.utils as utils
from superbench.benchmarks \
    import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode, Precision, DistributedImpl, DistributedBackend
from superbench.benchmarks.micro_benchmarks.dist_inference \
    import DistInference, ComputationKernelType, CommunicationKernelType, ActivationKernelType
from superbench.common.utils import network


# TODO - replace unittest.skip('no multiple GPUs') with a skipIfNoMultiGPUs decorator
@unittest.skip('no multiple GPUs')
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_dist_inference_normal():
    """Test pytorch-dist-inference benchmark on distributed normal case."""
    context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
    world_size = 2
    assert (BenchmarkRegistry.is_benchmark_context_valid(context))
    results = utils.simulated_ddp_distributed_benchmark(context, world_size)
    assert (results)
    for benchmark in results:
        # Check basic information.
        assert (benchmark)
        assert (isinstance(benchmark, DistInference))
        assert (benchmark.name == 'pytorch-dist-inference')
        assert (benchmark.type == BenchmarkType.MICRO)

        # Check predefined parameters of dist-inference benchmark.
        assert (benchmark._args.batch_size == 64)
        assert (benchmark._args.input_size == 1024)
        assert (benchmark._args.hidden_size == 1024)
        assert (benchmark._args.num_layers == 1)
        assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
        assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
        assert (benchmark._args.activation_kernel == ActivationKernelType.RELU)
        assert (benchmark._args.precision == Precision.FLOAT32)
        assert (benchmark._args.num_warmup == 50)
        assert (benchmark._args.num_steps == 10000)
        assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
        assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)

        # Check results and metrics.
        assert (benchmark.run_count == 1)
        assert (benchmark.return_code == ReturnCode.SUCCESS)
        # step_times
        assert (len(benchmark.raw_data) == 1)
        # return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
        assert (len(benchmark.result) == 7)


@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_dist_inference_fake_distributed():
    """Test pytorch-dist-inference benchmark on single gpu."""
    context = BenchmarkRegistry.create_benchmark_context('dist-inference', parameters='', framework=Framework.PYTORCH)
    port = network.get_free_port()
    assert (port)
    utils.setup_simulated_ddp_distributed_env(1, 0, port)
    benchmark = BenchmarkRegistry.launch_benchmark(context)

    # Check basic information.
    assert (benchmark)
    assert (isinstance(benchmark, DistInference))
    assert (benchmark.name == 'pytorch-dist-inference')
    assert (benchmark.type == BenchmarkType.MICRO)

    # Check predefined parameters of dist-inference benchmark.
    assert (benchmark._args.batch_size == 64)
    assert (benchmark._args.input_size == 1024)
    assert (benchmark._args.hidden_size == 1024)
    assert (benchmark._args.num_layers == 1)
    assert (benchmark._args.computation_kernel == ComputationKernelType.MATMUL)
    assert (benchmark._args.communication_kernel == CommunicationKernelType.ALLREDUCE)
    assert (benchmark._args.activation_kernel == ActivationKernelType.RELU)
    assert (benchmark._args.precision == Precision.FLOAT32)
    assert (benchmark._args.num_warmup == 50)
    assert (benchmark._args.num_steps == 10000)
    assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
    assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)

    # Check results and metrics.
    assert (benchmark.run_count == 1)
    assert (benchmark.return_code == ReturnCode.SUCCESS)
    # step_times
    assert (len(benchmark.raw_data) == 1)
    # return code + (avg, 50th, 90th, 95th, 99th, 99.9th)
    assert (len(benchmark.result) == 7)

    utils.clean_simulated_ddp_distributed_env()