Mirror of https://github.com/microsoft/DeepSpeed.git
Comms Benchmarks (#2040)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 76ea0534c1
Commit: 9b70ce56e7
@@ -39,7 +39,7 @@ repos:
     name: check-torchdist
     entry: ./scripts/check-torchdist.py
     language: script
-    exclude: ^(deepspeed/comm/|docs/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
+    exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
     # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

 - repo: https://github.com/codespell-project/codespell
benchmarks/communication/README.md
@@ -0,0 +1,65 @@
# Running Communication Benchmarks

To run benchmarks, there are two options:

1. Run a single communication operation:

For example, run with a single large message size:
<pre>
deepspeed all_reduce.py
</pre>

Scan across message sizes:
<pre>
deepspeed all_reduce.py --scan
</pre>

Each communication operation has its own benchmarking options. For `all_reduce.py`, for example:

<pre>
usage: ds_bench [-h] [--local_rank LOCAL_RANK] [--trials TRIALS] [--warmup WARMUP] [--maxsize MAXSIZE] [--async-op] [--bw-unit {Gbps,GBps}] [--backend {nccl}] [--dist {deepspeed,torch}] [--scan] [--dtype DTYPE] [--mem-factor MEM_FACTOR] [--debug]

optional arguments:
  -h, --help            show this help message and exit
  --local_rank LOCAL_RANK
  --trials TRIALS       Number of timed iterations
  --warmup WARMUP       Number of warmup (non-timed) iterations
  --maxsize MAXSIZE     Max message size as a power of 2
  --async-op            Enables non-blocking communication
  --bw-unit {Gbps,GBps}
  --backend {nccl}      Communication library to use
  --dist {deepspeed,torch}
                        Distributed DL framework to use
  --scan                Enables scanning all message sizes
  --dtype DTYPE         PyTorch tensor dtype
  --mem-factor MEM_FACTOR
                        Proportion of max available GPU memory to use for single-size evals
  --debug               Enables alltoall debug prints
</pre>
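
For instance, combining several of these flags, the following would time 100 iterations of an fp16 all-reduce after 10 warmup iterations and report bandwidth in GB/s:

<pre>
deepspeed all_reduce.py --trials=100 --warmup=10 --bw-unit=GBps --dtype=float16
</pre>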

2. Run all available communication benchmarks:

<pre>
deepspeed run_all.py
</pre>

Like the individual benchmarks, `run_all.py` accepts the same arguments (max message size, bw-unit, etc.). Simply pass the desired arguments to `run_all.py` and they'll be propagated to each comm op.
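
For example, the following invocation (using flags from the usage listing above) scans all message sizes, times 20 iterations per size, and reports bandwidth in GB/s:

<pre>
deepspeed run_all.py --scan --trials=20 --bw-unit=GBps
</pre>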

Note that `ds_bench` is a pre-packaged wrapper around `run_all.py`. Users can pass the same arguments as well:

<pre>
<path to deepspeed>/bin/ds_bench --scan --trials=10
</pre>
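
Because `ds_bench` re-launches itself under the `deepspeed` launcher when the usual distributed environment variables are not set (see `bin/ds_bench` below), it can also be started through the launcher explicitly. A sketch, assuming the launcher's standard `--num_gpus` flag:

<pre>
deepspeed --num_gpus=8 $(which ds_bench) --scan --bw-unit=GBps
</pre>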

# Adding Communication Benchmarks

To add new communication benchmarks, follow this general procedure (a sketch of steps 2 and 3 follows the list):

1. Copy a similar benchmark file (e.g. to add `reduce_scatter`, copy `all_reduce.py` as a template)
2. Add a new bw formula in `utils.get_bw`
3. Add a new maximum tensor element formula in `utils.max_numel`
4. Replace the comm op calls in the new file with find-replace
5. Find a good default `mem_factor` for the single-size (non-`--scan`) path
6. Add the new comm op to `run_all.py`
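
As an illustration, steps 2 and 3 for a hypothetical `broadcast` benchmark might look like the sketch below. The `broadcast` branches and their formulas are assumptions for the example, not part of the code in this PR:

<pre>
# In utils.get_bw: each rank receives the full message once, so the bus
# bandwidth of a broadcast equals its algorithmic throughput.
elif comm_op == "broadcast":
    tput = (size / duration)
    busbw = tput

# In utils.max_numel: broadcast needs only a single buffer per GPU,
# so the whole memory budget goes to one tensor (like allreduce/pt2pt).
elif comm_op == 'broadcast':
    elements_per_gpu = int(max_memory_per_gpu // dtype_size)
</pre>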

benchmarks/communication/all_gather.py
@@ -0,0 +1,153 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os
import math


# Run allgather and print metrics
def timed_allgather(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmup, establish connections, etc.
    for i in range(args.warmup):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('allgather', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    print_rank_0(
        f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
    )


def run_allgather(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'allgather')
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.scan:
        # Create list of message sizes
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
                # Delete original mat to avoid OOM
                del mat
                torch.cuda.empty_cache()
                output = torch.zeros(input.nelement() * world_size,
                                     dtype=getattr(torch, args.dtype)).cuda(local_rank)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_allgather(input, output, args)
    else:
        # all_gather_base saves memory
        if (args.dist == 'torch' and hasattr(torch.distributed, "_all_gather_base")) or (
                args.dist == 'deepspeed' and dist.has_allgather_base):
            mem_factor = args.mem_factor + 0.2
        else:
            mem_factor = args.mem_factor
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        sync_all()
        elements_per_gpu = max_numel(comm_op='allgather',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=mem_factor,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            # multiply each GPU's tensor by the rank to ease debugging
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            torch.cuda.empty_cache()
            output = torch.zeros(elements_per_gpu * world_size,
                                 dtype=getattr(torch, args.dtype)).cuda(local_rank)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return

        sync_all()
        timed_allgather(input, output, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_allgather(local_rank=rank, args=args)

benchmarks/communication/all_reduce.py
@@ -0,0 +1,109 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os
import math


def timed_allreduce(input, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmup, establish connections, etc.
    for i in range(args.warmup):
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('allreduce', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    print_rank_0(
        f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
    )


def run_allreduce(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'allreduce')

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_allreduce(input, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        # Don't need output tensor, so we double mem_factor
        elements_per_gpu = max_numel(comm_op='allreduce',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor * 2,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            input = ((mat.mul_(float(global_rank))).view(-1))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()
        timed_allreduce(input, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_allreduce(local_rank=rank, args=args)

benchmarks/communication/all_to_all.py
@@ -0,0 +1,129 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os
import math


def timed_alltoall(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmup, establish connections, etc.
    for i in range(args.warmup):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('alltoall', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    print_rank_0(
        f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
    )


def run_alltoall(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    # Prepare benchmark header
    print_header(args, 'alltoall')

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
                output = (mat.clone().view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_alltoall(input, output, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        elements_per_gpu = max_numel(comm_op='alltoall',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            torch.cuda.empty_cache()
            output = torch.zeros(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()

        if args.debug:
            for i in range(world_size):
                if i == global_rank:
                    print(f"Before AllToAll Input List at rank {global_rank}: {input}")
                dist.barrier()

        timed_alltoall(input, output, args)

        if args.debug:
            for i in range(world_size):
                if i == global_rank:
                    print(f"AllToAll Results at rank {global_rank}: {output}")
                dist.barrier()


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_alltoall(local_rank=rank, args=args)

benchmarks/communication/constants.py
@@ -0,0 +1,9 @@
import torch

DEFAULT_WARMUPS = 5
DEFAULT_TRIALS = 50
DEFAULT_TYPE = 'float'
DEFAULT_BACKEND = 'nccl'
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24

benchmarks/communication/pt2pt.py
@@ -0,0 +1,128 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os
import math


def timed_pt2pt(input, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmup, establish connections, etc.
    for i in range(args.warmup):
        if dist.get_rank() == 0:
            if args.async_op:
                dist.isend(input, 1)
            else:
                dist.send(input, 1)
        if dist.get_rank() == 1:
            if args.async_op:
                dist.irecv(input, src=0)
            else:
                dist.recv(input, src=0)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        if dist.get_rank() == 0:
            if args.async_op:
                dist.isend(input, 1)
            else:
                dist.send(input, 1)
        if dist.get_rank() == 1:
            if args.async_op:
                dist.irecv(input, src=0)
            else:
                dist.recv(input, src=0)

    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('pt2pt', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    print_rank_0(
        f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
    )


def run_pt2pt(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'pt2pt')
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.scan:
        # Create list of message sizes
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_pt2pt(input, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        # Don't need output tensor, so double mem_factor
        elements_per_gpu = max_numel(comm_op='pt2pt',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor * 2,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            input = ((mat.mul_(float(global_rank))).view(-1))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()
        timed_pt2pt(input, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_pt2pt(local_rank=rank, args=args)

benchmarks/communication/run_all.py
@@ -0,0 +1,34 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.all_reduce import run_allreduce
from benchmarks.communication.all_gather import run_allgather
from benchmarks.communication.all_to_all import run_alltoall
from benchmarks.communication.pt2pt import run_pt2pt
from benchmarks.communication.constants import *

import time
import argparse
import os


# For importing
def main(args, rank):
    init_processes(local_rank=rank, args=args)

    for comm_op in ['allreduce', 'alltoall', 'allgather', 'pt2pt']:
        if comm_op == 'allreduce':
            run_allreduce(local_rank=rank, args=args)
        if comm_op == 'allgather':
            run_allgather(local_rank=rank, args=args)
        if comm_op == 'alltoall':
            run_alltoall(local_rank=rank, args=args)
        if comm_op == 'pt2pt':
            run_pt2pt(local_rank=rank, args=args)


# For directly calling benchmark
if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    main(args, rank)

benchmarks/communication/utils.py
@@ -0,0 +1,185 @@
import torch
import os
import math
import argparse
from benchmarks.communication.constants import *

global dist


def init_torch_distributed(backend):
    global dist
    import torch.distributed as dist
    torch.distributed.init_process_group(backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_deepspeed_comm(backend):
    global dist
    import deepspeed
    import deepspeed.comm as dist
    deepspeed.init_distributed(dist_backend=backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_processes(local_rank, args):
    if args.dist == 'deepspeed':
        init_deepspeed_comm(args.backend)
    elif args.dist == 'torch':
        init_torch_distributed(args.backend)
    else:
        print_rank_0(f"distributed framework {args.dist} not supported")
        exit(0)


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


def print_header(args, comm_op):
    if comm_op == 'pt2pt':
        world_size = 2
    else:
        world_size = dist.get_world_size()
    tput = f'Throughput ({args.bw_unit})'
    busbw = f'BusBW ({args.bw_unit})'
    header = f"\n---- Performance of {comm_op} on {world_size} devices ---------------------------------------------------------\n"
    header += f"{'Size (Bytes)':20s} {'Description':25s} {'Duration':20s} {tput:20s} {busbw:20s}\n"
    header += "----------------------------------------------------------------------------------------------------"
    print_rank_0(header)


def get_bw(comm_op, size, duration, args):
    n = dist.get_world_size()
    tput = 0
    busbw = 0
    if comm_op == "alltoall":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "allgather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "allreduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt":
        tput = (size / duration)
        busbw = tput
    else:
        print_rank_0("wrong comm_op specified")
        exit(0)

    if args.bw_unit == 'Gbps':
        tput *= 8
        busbw *= 8

    return tput, busbw


def get_metric_strings(args, tput, busbw, duration):
    duration_ms = duration * 1e3
    duration_us = duration * 1e6
    tput = f'{tput / 1e9:.3f}'
    busbw = f'{busbw / 1e9:.3f}'

    if duration_us < 1e3:
        duration = f'{duration_us:.3f} us'
    else:
        duration = f'{duration_ms:.3f} ms'
    return tput, busbw, duration


def sync_all():
    torch.cuda.synchronize()
    dist.barrier()


def max_numel(comm_op, dtype, mem_factor, local_rank, args):
    dtype_size = torch._utils._element_size(dtype)
    max_memory_per_gpu = torch.cuda.get_device_properties(local_rank).total_memory * mem_factor
    if comm_op == 'allreduce' or comm_op == 'pt2pt':
        elements_per_gpu = int(max_memory_per_gpu // dtype_size)
    elif comm_op == 'allgather':
        # all_gather performance is lower for non-powers of two, and the output buffer size scales with world size
        # Therefore, divide by world size and round down to the nearest power of 2
        elements_per_gpu = int(max_memory_per_gpu // dtype_size // dist.get_world_size())
        elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
    elif comm_op == 'alltoall':
        # Number of elements must be divisible by world_size
        # all_to_all performance is lower for non-powers of two. Round down like allgather.
        elements_per_gpu = int(max_memory_per_gpu // dtype_size)
        elements_per_gpu = int(dist.get_world_size() * round(elements_per_gpu / dist.get_world_size()))
        elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
    else:
        print(f"This communication operation: {comm_op} is not supported yet")
        exit(0)
    return elements_per_gpu


# Helper function to pretty-print message sizes
def convert_size(size_bytes):
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])


def benchmark_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--trials", type=int, default=DEFAULT_TRIALS, help='Number of timed iterations')
    parser.add_argument("--warmup", type=int, default=DEFAULT_WARMUPS, help='Number of warmup (non-timed) iterations')
    parser.add_argument("--maxsize", type=int, default=DEFAULT_MAXSIZE, help='Max message size as a power of 2')
    parser.add_argument("--async-op", action="store_true", help='Enables non-blocking communication')
    parser.add_argument("--bw-unit", type=str, default=DEFAULT_UNIT, choices=['Gbps', 'GBps'])
    parser.add_argument("--backend", type=str, default=DEFAULT_BACKEND, choices=['nccl'], help='Communication library to use')
    parser.add_argument("--dist",
                        type=str,
                        default=DEFAULT_DIST,
                        choices=['deepspeed', 'torch'],
                        help='Distributed DL framework to use')
    parser.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
    parser.add_argument("--dtype", type=str, default=DEFAULT_TYPE, help='PyTorch tensor dtype')
    parser.add_argument("--mem-factor",
                        type=float,
                        default=.4,
                        help='Proportion of max available GPU memory to use for single-size evals')
    parser.add_argument("--debug", action="store_true", help='Enables alltoall debug prints')
    return parser

bin/ds_bench
@@ -0,0 +1,18 @@
#!/usr/bin/env python3

from benchmarks.communication.run_all import main
from benchmarks.communication.constants import *
from benchmarks.communication.utils import *
import argparse
import os
import sys

# Re-launch this same file with the deepspeed launcher. This is required since setuptools will
# auto-detect python files and insert a python shebang for both 'scripts' and 'entry_points',
# and this benchmark requires the DS launcher.
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if not all(map(lambda v: v in os.environ, required_env)):
    import subprocess
    subprocess.run("deepspeed $(which ds_bench) " + " ".join(sys.argv[1:]), shell=True)
else:
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    main(args, rank)

deepspeed/comm/comm.py
@@ -206,7 +206,7 @@ def allgather_fn(output_tensor: torch.Tensor,
                               group=group,
                               async_op=True)
     else:
-        if not has_warned_all_gather:
+        if not has_warned_all_gather and get_rank() == 0:
             utils.logger.warning(
                 "unable to find torch.distributed._all_gather_base. will fall back to "
                 "torch.distributed.all_gather which will result in suboptimal performance. "

setup.py
@@ -300,6 +300,7 @@ setup(name='deepspeed',
           'bin/ds',
           'bin/ds_ssh',
           'bin/ds_report',
+          'bin/ds_bench',
           'bin/dsr',
           'bin/ds_elastic'
       ],