Ability to initialize distributed backend outside deepspeed runtime (#608)

This commit is contained in:
Jeff Rasley 2020-12-17 23:17:19 -08:00 коммит произвёл GitHub
Родитель fd2f970bdf
Коммит 7435b2f10a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 175 добавлений и 136 удалений

@ -1 +1 @@
Subproject commit abb270641ca8c33476282bde29916c395a060ae9
Subproject commit 78d69cb2f89a27b1e9b072df8c3e47d00c024fdc

Просмотреть файл

@ -14,6 +14,7 @@ from .runtime.config import DeepSpeedConfig
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import log_dist
from .utils.distributed import init_distributed
from .pipe import PipelineModule

8
deepspeed/constants.py Normal file
Просмотреть файл

@ -0,0 +1,8 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500

Просмотреть файл

@ -1,10 +1,5 @@
# Copyright 2020 The Microsoft DeepSpeed Team
#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024

Просмотреть файл

@ -16,7 +16,7 @@ import base64
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger

Просмотреть файл

@ -19,8 +19,8 @@ from copy import deepcopy
import torch.cuda
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT, \
PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
DLTS_HOSTFILE = "/job/hostfile"

Просмотреть файл

@ -73,11 +73,6 @@ MAX_GRAD_NORM = 'max_grad_norm'
ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = "29500"
# Steps
STEPS_PER_PRINT = "steps_per_print"
STEPS_PER_PRINT_DEFAULT = 10

Просмотреть файл

@ -24,12 +24,12 @@ from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA
PLD_THETA, PLD_GAMMA
from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
@ -130,29 +130,14 @@ class DeepSpeedEngine(Module):
if dist_init_required is False:
assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
# DeepSpeed will initialize torch distributed only if the user has not already intialized it.
if dist_init_required and not dist.is_initialized():
# discover using mpi4py if user specifies the flag
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
# if in Azure ML environment and user specified this flag, notify the user to remove the flag.
if self._in_aml():
logger.warning(
"Please remove the --deepspeed_mpi flag if running on AzureML.")
self._mpi_check(args, dist_init_required)
else:
# detect if we are in Azure ML environment
if self._in_aml():
self._set_environment_variables_for_nccl_backend(args)
logger.info("Initializing torch distributed with backend: {}".format(
self.dist_backend))
dist.init_process_group(backend=self.dist_backend)
# Initialize torch distributed if needed
init_distributed(dist_backend=self.dist_backend)
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
self._init_distributed(dist_init_required)
self._set_distributed_vars()
if self.tensorboard_enabled() and self.global_rank == 0:
self.summary_writer = self.get_summary_writer()
@ -209,87 +194,6 @@ class DeepSpeedEngine(Module):
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
def _in_aml(self):
# read AzureML environment variable to detect if we are using an Azure ML environment
if 'AZUREML_EXPERIMENT_ID' in os.environ:
return True
else:
return False
def _set_environment_variables_for_nccl_backend(self,
args,
master_port=6105,
verbose=True):
"""Helper routine to get and set environment variables.
This is adapted from Azure ML's documentation available from:
https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
os.environ["MASTER_ADDR"] = master_node_params[0]
# Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = str(master_port)
else:
os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
os.environ["MASTER_PORT"] = "54965"
print("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
if verbose:
logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
args.local_rank,
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
def _mpi_check(self, args, dist_init_required):
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
args.local_rank = local_rank
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT
logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
args.local_rank,
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())
def pld_enabled(self):
return self._config.pld_enabled
@ -497,7 +401,7 @@ class DeepSpeedEngine(Module):
else:
return None
def _init_distributed(self, dist_init_required):
def _set_distributed_vars(self):
if self.local_rank >= 0:
torch.cuda.set_device(self.local_rank)
self.device = torch.device("cuda", self.local_rank)

Просмотреть файл

@ -1,2 +1,3 @@
from deepspeed.utils.logging import logger, log_dist
from .logging import logger, log_dist
from .distributed import init_distributed
from deepspeed.runtime.dataloader import RepeatingLoader

Просмотреть файл

@ -0,0 +1,129 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import os
import torch
from .logging import logger
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
def init_distributed(dist_backend="nccl",
auto_mpi_discovery=True,
distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
verbose=True):
"""
Initialize torch.distributed backend, potentially performing MPI discovery if needed
Arguments:
dist_backend (str): torch distributed backend, e.g., nccl, mpi, gloo
auto_mpi_discovery (bool): if distributed environment variables are not set, attempt to discover them from MPI
distributed_port (int, optional): torch distributed backend port
verbose (bool, optional): verbose logging
"""
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
logger.info(
"Not using the DeepSpeed or torch.distributed launchers, attempting to detect MPI environment..."
)
if in_aml() and not in_dlts():
patch_aml_env_for_torch_nccl_backend(verbose=verbose)
else:
mpi_discovery(distributed_port=distributed_port, verbose=verbose)
if not torch.distributed.is_initialized():
if verbose:
logger.info(
"Initializing torch distributed with backend: {}".format(dist_backend))
torch.distributed.init_process_group(backend=dist_backend)
def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
"""
Discovery MPI environment via mpi4py and map to relevant torch.distributed state
"""
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(distributed_port)
if verbose:
logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if torch.distributed.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())
def in_aml():
# Are we running inside an Azure Machine Learning (AML) environment?
return 'AZUREML_EXPERIMENT_ID' in os.environ
def in_dlts():
# Are we running on a DLTS cluster?
return 'DLTS_JOB_ID' in os.environ
def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""Helper routine to get and set environment variables.
This is adapted from Azure ML's documentation available from:
https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
os.environ["MASTER_ADDR"] = master_node_params[0]
# Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = str(master_port)
else:
os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
os.environ["MASTER_PORT"] = "54965"
if verbose:
logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
if verbose:
logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))

Просмотреть файл

@ -216,25 +216,27 @@ DeepSpeed will then make sure that these environment variables are set when
launching each process on every node across their training job.
### MPI Compatibility
### MPI and AzureML Compatibility
As described above, DeepSpeed provides its own parallel launcher to help launch
multi-node/multi-gpu training jobs. If you prefer to launch your training job
using MPI (e.g., mpirun), we provide support for this. It should be noted that
DeepSpeed will still use the torch distributed NCCL backend and *not* the MPI
backend. To launch your training job with mpirun + DeepSpeed you simply pass us
an additional flag `--deepspeed_mpi`. DeepSpeed will then use
[mpi4py](https://pypi.org/project/mpi4py/) to discover the MPI environment (e.g.,
rank, world size) and properly initialize torch distributed for training. In this
case you will explicitly invoke `python` to launch your model script instead of using
the `deepspeed` launcher, here is an example:
```bash
mpirun <mpi-args> python \
<client_entry.py> <client args> \
--deepspeed_mpi --deepspeed --deepspeed_config ds_config.json
```
backend.
If you want to use this feature of DeepSpeed, please ensure that mpi4py is
installed via `pip install mpi4py`.
To launch your training job with mpirun + DeepSpeed or with AzureML (which uses
mpirun as a launcher backend) you simply need to install the
[mpi4py](https://pypi.org/project/mpi4py/) python package. DeepSpeed will use
this to discover the MPI environment and pass the necessary state (e.g., world
size, rank) to the torch distributed backend.
If you are using model parallelism, pipeline parallelism, or otherwise require
torch.distributed calls before calling `deepspeed.initialize(..)` we provide
the same MPI support with an additional DeepSpeed API call. Replace your initial
`torch.distributed.init_process_group(..)` call with:
```python
deepspeed.init_distributed()
```
## Resource Configuration (single-node)
In the case that we are only running on a single node (with one or more GPUs)

Просмотреть файл

@ -171,5 +171,5 @@ else
pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl"
pdsh -w $hosts "ds_report"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rmdir $tmp_wheel_path; fi"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rm $tmp_wheel_path/*.txt; rmdir $tmp_wheel_path; fi"
fi

Просмотреть файл

@ -5,6 +5,8 @@ import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import deepspeed
import pytest
# Worker timeout *after* the first worker has completed.
@ -33,10 +35,12 @@ def distributed_test(world_size=2, backend='nccl'):
"""Initialize torch.distributed and execute the user function. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29503'
dist.init_process_group(backend=backend,
init_method='env://',
rank=local_rank,
world_size=num_procs)
os.environ['LOCAL_RANK'] = str(local_rank)
# NOTE: unit tests don't support multi-node so local_rank == global rank
os.environ['RANK'] = str(local_rank)
os.environ['WORLD_SIZE'] = str(num_procs)
deepspeed.init_distributed(dist_backend=backend)
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)