зеркало из https://github.com/microsoft/DeepSpeed.git
ZeRO-2 (#217)
Updates for ZeRO stage 2 + ZeRO stage 1 w. RS Co-authored-by: Tunji Ruwase <olruwase@microsoft.com> Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com> Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com> Co-authored-by: Elton Zheng <eltonz@microsoft.com> Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com> Co-authored-by: yuxionghe <yuxhe@microsoft.com> Co-authored-by: Arash Ashari <arashari@microsoft.com>
This commit is contained in:
Родитель
c61e23b4b1
Коммит
f2ac7eafd5
|
@ -2,6 +2,7 @@
|
|||
.idea/
|
||||
*~
|
||||
*.swp
|
||||
*.log
|
||||
deepspeed/git_version_info.py
|
||||
|
||||
# Build + installation data
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 9e2c735f5aabe48395c03a276fa7a0c51f6d3025
|
||||
Subproject commit 274787a189b265814ed75dd5ddeae2dce026ea88
|
106
README.md
106
README.md
|
@ -1,23 +1,30 @@
|
|||
[![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
|
||||
[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
|
||||
|
||||
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization library that makes distributed training easy,
|
||||
efficient, and effective.
|
||||
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization
|
||||
library that makes distributed training easy, efficient, and effective.
|
||||
|
||||
<p align="center"><i><b>10x Larger Models</b></i></p>
|
||||
<p align="center"><i><b>5x Faster Training</b></i></p>
|
||||
<p align="center"><i><b>10x Faster Training</b></i></p>
|
||||
<p align="center"><i><b>Minimal Code Change</b></i></p>
|
||||
|
||||
DeepSpeed can train deep learning models with over a hundred billion parameters on current
|
||||
generation of GPU clusters, while achieving over 5x in system performance
|
||||
generation of GPU clusters, while achieving over 10x in system performance
|
||||
compared to the state-of-art. Early adopters of DeepSpeed have already produced
|
||||
a language model (LM) with over 17B parameters called
|
||||
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
|
||||
establishing a new SOTA in the LM category.
|
||||
|
||||
**_For further documentation, tutorials, and technical deep-dives please see [deepspeed.ai](https://www.deepspeed.ai/)!_**
|
||||
|
||||
|
||||
# News
|
||||
* [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
|
||||
* [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
|
||||
* [2020/05/19] [ZeRO-2 empowers training models as large as 170 billion parameters up to 10x faster compared to state-of-the-art](https://www.deepspeed.ai/news/2020/05/19/zero-stage2.html)
|
||||
<span style="color:dodgerblue">**[_NEW_]**</span>
|
||||
* [2020/05/19] [DeepSpeed optimizes transformer kernels to achieve world’s fastest BERT training record: 44 minutes on 1024 NVIDIA V100 GPUs](https://www.deepspeed.ai/news/2020/05/19/bert-record.html)
|
||||
<span style="color:dodgerblue">**[_NEW_]**</span>
|
||||
* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
|
||||
* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
|
||||
|
||||
|
||||
# Table of Contents
|
||||
|
@ -39,93 +46,6 @@ a large model easily runs out of memory with pure data parallelism and it is
|
|||
difficult to use model parallelism. DeepSpeed addresses these challenges to
|
||||
accelerate model development *and* training.
|
||||
|
||||
## Distributed, Effective, and Efficient Training with Ease
|
||||
The DeepSpeed API is a lightweight wrapper on [PyTorch](https://pytorch.org/). This
|
||||
means that you can use everything you love in PyTorch and without learning a new
|
||||
platform. In addition, DeepSpeed manages all of the boilerplate state-of-the-art
|
||||
training techniques, such as distributed training, mixed precision, gradient
|
||||
accumulation, and checkpoints so that you can focus on your model development. Most
|
||||
importantly, you can leverage the distinctive efficiency and effectiveness benefit of
|
||||
DeepSpeed to boost speed and scale with just a few lines of code changes to your PyTorch
|
||||
models.
|
||||
|
||||
## Speed
|
||||
DeepSpeed achieves high performance and fast convergence through a combination of
|
||||
efficiency optimizations on compute/communication/memory/IO and effectiveness
|
||||
optimizations on advanced hyperparameter tuning and optimizers. For example:
|
||||
|
||||
* DeepSpeed trains BERT-large to parity in 14 hours using 64 GPUs (4 DGX-2 boxes) and in
|
||||
3.7 hours using 256 GPUs (16 DGX-2 boxes).
|
||||
|
||||
**BERT-large Training Times**
|
||||
|
||||
| Devices | Source | Training Time (hours) |
|
||||
| ------------- | --------- | ---------------------:|
|
||||
| 64 TPUs | Google | 96 |
|
||||
| 64 V100 GPUs | DeepSpeed | **14** |
|
||||
| 256 V100 GPUs | NVIDIA | 3.9 |
|
||||
| 256 V100 GPUs | DeepSpeed | **3.7** |
|
||||
|
||||
*Read more*: [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/)
|
||||
|
||||
* DeepSpeed trains GPT2 (1.5 billion parameters) 3.75x faster than state-of-art, NVIDIA
|
||||
Megatron on Azure GPUs.
|
||||
|
||||
*Read more*: [GPT tutorial](https://www.deepspeed.ai/tutorials/megatron/)
|
||||
|
||||
|
||||
|
||||
## Memory efficiency
|
||||
DeepSpeed provides memory-efficient data parallelism and enables training models without
|
||||
model parallelism. For example, DeepSpeed can train models with up to 6 billion parameters on
|
||||
NVIDIA V100 GPUs with 32GB of device memory. In comparison, existing frameworks (e.g.,
|
||||
PyTorch's Distributed Data Parallel) run out of memory with 1.5 billion parameter models.
|
||||
|
||||
DeepSpeed reduces the training memory footprint through a novel solution called Zero
|
||||
Redundancy Optimizer (ZeRO). Unlike basic data parallelism where memory states are
|
||||
replicated across data-parallel processes, ZeRO partitions model states to save
|
||||
significant memory. The current implementation (stage 1 of ZeRO) reduces memory by up to
|
||||
4x relative to the state-of-art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054).
|
||||
|
||||
With this impressive memory reduction, early adopters of DeepSpeed have already
|
||||
produced a language model (LM) with over 17B parameters called
|
||||
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
|
||||
establishing a new SOTA in the LM category.
|
||||
|
||||
|
||||
## Scalability
|
||||
DeepSpeed supports efficient data parallelism, model parallelism, and their
|
||||
combination. ZeRO boosts the scaling capability and efficiency further.
|
||||
* DeepSpeed provides system support to run models up to 100 billion parameters,
|
||||
10x larger than the state-of-art (8 billion NVIDIA GPT, 11 billion Google T5).
|
||||
* DeepSpeed can run large models more efficiently, up to 6x faster for models with
|
||||
various sizes spanning 1.5B to 100B. More specifically, the data parallelism powered by ZeRO
|
||||
is complementary and can be combined with different types of model parallelism. It allows
|
||||
DeepSpeed to fit models using lower degree of model parallelism and higher batch size, offering
|
||||
significant performance gains compared to using model parallelism alone.
|
||||
|
||||
*Read more*: [technical report](https://arxiv.org/abs/1910.02054)
|
||||
and [GPT tutorial](https://www.deepspeed.ai/tutorials/megatron/)
|
||||
|
||||
![DeepSpeed-vs-Megatron](./docs/assets/images/DeepSpeed-vs-Megatron.png)
|
||||
<p align="center">
|
||||
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of NVIDIA Megatron-LM) over using Megatron-LM alone.</em>
|
||||
</p>
|
||||
|
||||
|
||||
## Fast convergence for effectiveness
|
||||
DeepSpeed supports advanced hyperparameter tuning and large batch size
|
||||
optimizers such as [LAMB](https://arxiv.org/abs/1904.00962). These improve the
|
||||
effectiveness of model training and reduce the number of samples required to
|
||||
convergence to desired accuracy.
|
||||
|
||||
*Read more*: [Tuning tutorial](https://www.deepspeed.ai/tutorials/1Cycle/) and [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/)
|
||||
|
||||
|
||||
## Usability
|
||||
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to six billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.3 billion parameters. In addition, DeepSpeed conveniently supports flexible combination of ZeRO-powered data parallelism with custom model parallelisms, such as tensor slicing of NVIDIA's Megatron-LM.
|
||||
|
||||
|
||||
# Features
|
||||
|
||||
Below we provide a brief feature list, see our detailed [feature
|
||||
|
|
|
@ -35,11 +35,6 @@ jobs:
|
|||
pre-commit run --all-files
|
||||
displayName: 'Formatting checks'
|
||||
|
||||
- script: |
|
||||
pip install --user pylint
|
||||
pylint --exit-zero deepspeed/
|
||||
displayName: 'Code linter'
|
||||
|
||||
- script: |
|
||||
pytest --forked --verbose tests/unit/
|
||||
displayName: 'Unit tests'
|
||||
|
|
|
@ -6,6 +6,8 @@ from deepspeed.pt.deepspeed_light import DeepSpeedLight
|
|||
from deepspeed.pt.deepspeed_light import ADAM_OPTIMIZER, LAMB_OPTIMIZER
|
||||
from deepspeed.pt.deepspeed_lr_schedules import add_tuning_arguments
|
||||
|
||||
import deepspeed.pt.deepspeed_checkpointing as checkpointing
|
||||
|
||||
try:
|
||||
from deepspeed.git_version_info import git_hash, git_branch
|
||||
except ImportError:
|
||||
|
@ -14,7 +16,7 @@ except ImportError:
|
|||
|
||||
# Export version information
|
||||
__version_major__ = 0
|
||||
__version_minor__ = 1
|
||||
__version_minor__ = 2
|
||||
__version_patch__ = 0
|
||||
__version__ = '.'.join(
|
||||
map(str,
|
||||
|
@ -33,7 +35,8 @@ def initialize(args,
|
|||
lr_scheduler=None,
|
||||
mpu=None,
|
||||
dist_init_required=None,
|
||||
collate_fn=None):
|
||||
collate_fn=None,
|
||||
config_params=None):
|
||||
"""Initialize the DeepSpeed Engine.
|
||||
|
||||
Arguments:
|
||||
|
@ -91,7 +94,8 @@ def initialize(args,
|
|||
lr_scheduler=lr_scheduler,
|
||||
mpu=mpu,
|
||||
dist_init_required=dist_init_required,
|
||||
collate_fn=collate_fn)
|
||||
collate_fn=collate_fn,
|
||||
config_params=config_params)
|
||||
|
||||
return_items = [
|
||||
engine,
|
||||
|
|
|
@ -0,0 +1,724 @@
|
|||
'''
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT license.
|
||||
|
||||
Use to partition the activations stored for backward propagation
|
||||
Therefore reduces the memory consumption
|
||||
Also implements CPU checkpointing and contiguous memory checkpointing
|
||||
Reduces memory consumption and memory fragmentation
|
||||
|
||||
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
|
||||
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
|
||||
'''
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
import contextlib
|
||||
import torch.distributed as dist
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.cuda import _lazy_call, device as device_ctx_manager
|
||||
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
|
||||
import torch.distributed as dist
|
||||
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
|
||||
|
||||
#DeepSpeed Checkpointing Enabled or Disabled
|
||||
deepspeed_checkpointing_enabled = False
|
||||
|
||||
#MP parameters
|
||||
mpu = None
|
||||
mp_rank = None
|
||||
mp_size = None
|
||||
mp_group = None
|
||||
|
||||
#Model Parameters
|
||||
num_layers = None
|
||||
|
||||
#Checkpointing buffers
|
||||
contiguous_data_buffers = []
|
||||
data_offsets = []
|
||||
|
||||
contiguous_size_buffers = []
|
||||
size_offsets = []
|
||||
|
||||
timers = None
|
||||
|
||||
#optimization flags
|
||||
PARTITION_ACTIVATIONS = False
|
||||
PA_TO_CPU = False
|
||||
CONTIGUOUS_CHECKPOINTING = False
|
||||
SYNCHRONIZE = False
|
||||
PROFILE_TIME = False
|
||||
|
||||
|
||||
def see_memory_usage(message, force=False):
|
||||
#return
|
||||
if not force:
|
||||
return
|
||||
#dist.barrier()
|
||||
if dist.get_rank() == 0:
|
||||
print(message)
|
||||
print("Memory Allocated ",
|
||||
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
|
||||
"GigaBytes")
|
||||
print("Max Memory Allocated ",
|
||||
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
|
||||
"GigaBytes")
|
||||
print("Cache Allocated ",
|
||||
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
|
||||
"GigaBytes")
|
||||
print("Max cache Allocated ",
|
||||
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
|
||||
"GigaBytes")
|
||||
print(" ")
|
||||
#input("Press Any Key To Continue ..")
|
||||
|
||||
|
||||
# Default name for the model parallel rng tracker.
|
||||
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
|
||||
transport_stream = None
|
||||
cuda_device = None
|
||||
|
||||
|
||||
def detach_variable(inputs, device=None):
|
||||
if isinstance(inputs, tuple):
|
||||
out = []
|
||||
for inp in inputs:
|
||||
if not isinstance(inp, torch.Tensor):
|
||||
out.append(inp)
|
||||
continue
|
||||
|
||||
requires_grad = inp.requires_grad
|
||||
|
||||
if device is not None:
|
||||
x = inp.to(device=device)
|
||||
else:
|
||||
x = inp
|
||||
|
||||
x = x.detach()
|
||||
x.requires_grad = requires_grad
|
||||
out.append(x)
|
||||
return tuple(out)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Only tuple of tensors is supported. Got Unsupported input type: ",
|
||||
type(inputs).__name__)
|
||||
|
||||
|
||||
def _set_cuda_rng_state(new_state, device=-1):
|
||||
"""Sets the random number generator state of the current GPU.
|
||||
|
||||
Argumentss:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
|
||||
with a single change: the input state is not cloned. Cloning caused
|
||||
major performance issues for +4 GPU cases.
|
||||
"""
|
||||
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
|
||||
# older PyTorch
|
||||
def cb():
|
||||
with device_ctx_manager(device):
|
||||
_C._cuda_setRNGState(new_state)
|
||||
else:
|
||||
# newer PyTorch
|
||||
if device == -1:
|
||||
device = torch.device('cuda')
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device('cuda', device)
|
||||
|
||||
def cb():
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = torch.cuda.current_device()
|
||||
default_generator = torch.cuda.default_generators[idx]
|
||||
default_generator.set_state(new_state)
|
||||
|
||||
_lazy_call(cb)
|
||||
|
||||
|
||||
class CudaRNGStatesTracker:
|
||||
"""Tracker for the cuda RNG states.
|
||||
|
||||
Using the `add` method, a cuda rng state is initialized based on
|
||||
the input `seed` and is assigned to `name`. Later, by forking the
|
||||
rng state, we can perform operations and return to our starting
|
||||
cuda state.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Map from a string name to the cuda rng state.
|
||||
self.states_ = {}
|
||||
# Seeds are just for book keeping and ensure no seed is set twice.
|
||||
self.seeds_ = set()
|
||||
|
||||
def reset(self):
|
||||
"""Set to the initial state (no tracker)."""
|
||||
self.states_ = {}
|
||||
self.seeds_ = set()
|
||||
|
||||
def get_states(self):
|
||||
"""Get rng states. Copy the dictionary so we have direct
|
||||
pointers to the states, not just a pointer to the dictionary."""
|
||||
states = {}
|
||||
for name in self.states_:
|
||||
states[name] = self.states_[name]
|
||||
return states
|
||||
|
||||
def set_states(self, states):
|
||||
"""Set the rng states. For efficiency purposes, we do not check
|
||||
the size of seed for compatibility."""
|
||||
self.states_ = states
|
||||
|
||||
def add(self, name, seed):
|
||||
"""Track the rng state."""
|
||||
# Check seed is not already used.
|
||||
if seed in self.seeds_:
|
||||
raise Exception('seed {} already exists'.format(seed))
|
||||
self.seeds_.add(seed)
|
||||
# Check that state is not already defined.
|
||||
if name in self.states_:
|
||||
raise Exception('cuda rng state {} already exists'.format(name))
|
||||
# Get the current rng state.
|
||||
orig_rng_state = torch.cuda.get_rng_state()
|
||||
# Set the new state and store it.
|
||||
torch.cuda.manual_seed(seed)
|
||||
self.states_[name] = torch.cuda.get_rng_state()
|
||||
# Reset rng state to what it was.
|
||||
_set_cuda_rng_state(orig_rng_state)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
|
||||
"""Fork the cuda rng state, perform operations, and exit with
|
||||
the original state."""
|
||||
# Check if we have added the state
|
||||
if name not in self.states_:
|
||||
raise Exception('cuda rng state {} is not added'.format(name))
|
||||
# Store current rng state.
|
||||
orig_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
# Set rng state to the desired one
|
||||
_set_cuda_rng_state(self.states_[name])
|
||||
# Do the stuff we wanted to do.
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Update the current rng state for later use.
|
||||
self.states_[name] = torch.cuda.get_rng_state()
|
||||
# And set the state to the original state we started with.
|
||||
_set_cuda_rng_state(orig_cuda_rng_state)
|
||||
|
||||
|
||||
# RNG tracker object.
|
||||
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
|
||||
|
||||
|
||||
def get_cuda_rng_tracker():
|
||||
"""Get cuda rng tracker."""
|
||||
return _CUDA_RNG_STATE_TRACKER
|
||||
|
||||
|
||||
def model_parallel_cuda_manual_seed(seed):
|
||||
"""Initialize model parallel cuda seed.
|
||||
|
||||
This function should be called after the model parallel is
|
||||
initialized. Also, no torch.cuda.manual_seed should be called
|
||||
after this function. Basically, this is replacement for that
|
||||
function.
|
||||
Two set of RNG states are tracked:
|
||||
default state: This is for data parallelism and is the same among a
|
||||
set of model parallel GPUs but different across
|
||||
different model paralle groups. This is used for
|
||||
example for dropout in the non-model-parallel regions.
|
||||
model-parallel state: This state is different among a set of model
|
||||
parallel GPUs, but the same across data parallel
|
||||
groups. This is used for example for dropout in
|
||||
model parallel regions.
|
||||
"""
|
||||
global mpu
|
||||
# 2718 is just for fun and any POSITIVE value will work.
|
||||
offset = seed + 2718
|
||||
model_parallel_seed = offset + mpu.get_model_parallel_rank()
|
||||
# Data parallel gets the original sedd.
|
||||
data_parallel_seed = seed
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> initializing model parallel cuda seeds on global rank {}, '
|
||||
'model parallel rank {}, and data parallel rank {} with '
|
||||
'model parallel seed: {} and data parallel seed: {}'.format(
|
||||
torch.distributed.get_rank(),
|
||||
mpu.get_model_parallel_rank(),
|
||||
mpu.get_data_parallel_rank(),
|
||||
model_parallel_seed,
|
||||
data_parallel_seed),
|
||||
flush=True)
|
||||
_CUDA_RNG_STATE_TRACKER.reset()
|
||||
# Set the default state.
|
||||
torch.cuda.manual_seed(data_parallel_seed)
|
||||
# and model parallel state.
|
||||
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
|
||||
|
||||
|
||||
def get_partition_start(item):
|
||||
global mp_rank, mp_size, mp_group
|
||||
size = item.numel()
|
||||
partition_size = size / mp_size
|
||||
start = partition_size * mp_rank
|
||||
return int(start)
|
||||
|
||||
|
||||
def get_partition_size(item):
|
||||
global mp_rank, mp_size, mp_group
|
||||
size = item.numel()
|
||||
assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size"
|
||||
partition_size = size / mp_size
|
||||
return int(partition_size)
|
||||
|
||||
|
||||
def get_full_inputs(tensors, device=None):
|
||||
inputs = []
|
||||
num_args = int(len(tensors) / 2)
|
||||
for i in range(num_args - 1):
|
||||
|
||||
item = tensors[2 * i]
|
||||
size = tensors[2 * i + 1]
|
||||
|
||||
partition_size = item.numel()
|
||||
tensor_size = partition_size * mp_size
|
||||
if device is not None:
|
||||
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
|
||||
else:
|
||||
flat_tensor = torch.zeros([tensor_size],
|
||||
dtype=item.dtype,
|
||||
device=item.device)
|
||||
partitions = []
|
||||
for i in range(mp_size):
|
||||
part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
|
||||
if i == mp_rank:
|
||||
part_i.copy_(item)
|
||||
partitions.append(part_i)
|
||||
if mp_group is not None:
|
||||
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
|
||||
input_tensor = flat_tensor.view(list(size.numpy()))
|
||||
item.data = input_tensor.data
|
||||
|
||||
inputs.append(item)
|
||||
inputs.append(tensors[-2])
|
||||
|
||||
return tuple(inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
"""This function is adapted from torch.utils.checkpoint with
|
||||
two main changes:
|
||||
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
|
||||
2) the states in the model parallel tracker are also properly
|
||||
tracked/set/reset.
|
||||
3) Performance activation partitioning, contiguous memory optimization
|
||||
4) CPU Checkpointing
|
||||
5) Profile forward and backward functions
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, *args):
|
||||
global mpu, timers, SYNCHRONIZE, PROFILE_TIME
|
||||
|
||||
if SYNCHRONIZE:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if timers is None and PROFILE_TIME:
|
||||
timers = Timers()
|
||||
|
||||
if PROFILE_TIME:
|
||||
timers('forward').start()
|
||||
|
||||
ctx.run_function = run_function
|
||||
global num_layers
|
||||
global mp_rank, mp_size, mp_group
|
||||
global contiguous_data_buffers, contiguous_size_buffers
|
||||
global data_offsets, size_offsets
|
||||
if mp_rank is None:
|
||||
if mpu is not None:
|
||||
mp_rank = mpu.get_model_parallel_rank()
|
||||
mp_size = mpu.get_model_parallel_world_size()
|
||||
mp_group = mpu.get_model_parallel_group()
|
||||
else:
|
||||
mp_rank = 0
|
||||
mp_size = 1
|
||||
mp_group = None
|
||||
|
||||
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
|
||||
|
||||
if cuda_device is None:
|
||||
see_memory_usage("First Forward Begining", force=True)
|
||||
if dist.get_rank() == 0:
|
||||
print(f"Activation Checkpointing Information")
|
||||
print(
|
||||
f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
|
||||
)
|
||||
print(
|
||||
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
|
||||
)
|
||||
print(f"----Synchronization {SYNCHRONIZE}")
|
||||
print(f"----Profiling {PROFILE_TIME}")
|
||||
|
||||
cuda_device = torch.cuda.current_device()
|
||||
transport_stream = torch.cuda.Stream(device=cuda_device)
|
||||
|
||||
if PARTITION_ACTIVATIONS:
|
||||
#inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
|
||||
#inputs.append(args[-1])
|
||||
|
||||
inputs = []
|
||||
for i, item in enumerate(args[:-1]):
|
||||
partition_size = get_partition_size(item)
|
||||
partition = item.detach().contiguous().view(-1).narrow(
|
||||
0,
|
||||
get_partition_start(item),
|
||||
partition_size).clone()
|
||||
|
||||
if CONTIGUOUS_CHECKPOINTING:
|
||||
buffer_device = torch.device(
|
||||
'cpu') if PA_TO_CPU else partition.device
|
||||
|
||||
if i >= len(contiguous_data_buffers):
|
||||
tensor_list = [
|
||||
torch.tensor(()).new_empty([partition_size],
|
||||
dtype=partition.dtype,
|
||||
device=buffer_device)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
contiguous_data_buffers.append(tensor_list)
|
||||
data_offsets.append(0)
|
||||
elif contiguous_data_buffers[i] is None:
|
||||
tensor_list = [
|
||||
torch.tensor(()).new_empty([partition_size],
|
||||
dtype=partition.dtype,
|
||||
device=buffer_device)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
contiguous_data_buffers[i] = tensor_list
|
||||
data_offsets[i] = 0
|
||||
|
||||
contiguous_partition = contiguous_data_buffers[i][
|
||||
data_offsets[i]].data.copy_(partition.data)
|
||||
data_offsets[i] = data_offsets[i] + 1
|
||||
inputs.append(contiguous_partition)
|
||||
else:
|
||||
partition = partition.cpu() if PA_TO_CPU else partition
|
||||
inputs.append(partition)
|
||||
|
||||
inputs.append(args[-1])
|
||||
|
||||
#just in case something funky is happening such as reuse of inputs
|
||||
inputs_cuda = [item.to(cuda_device) for item in args]
|
||||
|
||||
# Copy the rng states.
|
||||
ctx.fwd_cpu_rng_state = torch.get_rng_state()
|
||||
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
|
||||
|
||||
#ctx.save_for_backward(*args)
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*inputs_cuda)
|
||||
|
||||
del inputs_cuda
|
||||
|
||||
#with torch.cuda.stream(transport_stream):
|
||||
#if PARTITION_ACTIVATIONS:
|
||||
# new_args = []
|
||||
# for arg, inp in zip(args,inputs):
|
||||
# size= torch.tensor(arg.size())
|
||||
# arg.data = inp.data
|
||||
# new_args.append(arg)
|
||||
# new_args.append(size)
|
||||
# ctx.save_for_backward(*new_args)
|
||||
|
||||
if PARTITION_ACTIVATIONS:
|
||||
new_args = []
|
||||
for i, (arg, inp) in enumerate(zip(args, inputs)):
|
||||
size = torch.tensor(arg.size())
|
||||
|
||||
arg.data = inp.data
|
||||
new_args.append(arg)
|
||||
|
||||
if CONTIGUOUS_CHECKPOINTING:
|
||||
numel = size.numel()
|
||||
if i >= len(contiguous_size_buffers):
|
||||
tmp = torch.tensor(())
|
||||
contiguous_size_buffers.append(
|
||||
tmp.new_empty([numel * num_layers],
|
||||
dtype=size.dtype,
|
||||
device=size.device))
|
||||
size_offsets.append(0)
|
||||
elif contiguous_size_buffers[i] is None:
|
||||
tmp = torch.tensor(())
|
||||
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
|
||||
dtype=size.dtype,
|
||||
device=size.device)
|
||||
size_offsets[i] = 0
|
||||
|
||||
contiguous_size = contiguous_size_buffers[i].narrow(
|
||||
0,
|
||||
size_offsets[i],
|
||||
numel).data.copy_(size.data)
|
||||
contiguous_size = contiguous_size.view_as(size)
|
||||
size_offsets[i] = size_offsets[i] + numel
|
||||
new_args.append(contiguous_size)
|
||||
else:
|
||||
new_args.append(size)
|
||||
#if dist.get_rank() == 0:
|
||||
# print (f"The stored tensor is {contiguous_size} and orginal one is {size} ")
|
||||
|
||||
ctx.save_for_backward(*new_args)
|
||||
else:
|
||||
ctx.save_for_backward(*args)
|
||||
if PROFILE_TIME:
|
||||
timers('forward').stop()
|
||||
timers.log(['forward'])
|
||||
if SYNCHRONIZE:
|
||||
torch.cuda.synchronize()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
global timers
|
||||
#see_memory_usage("In backward", force=True)
|
||||
#removing pointers to the contiguous buffer memory
|
||||
#so that they can be garbage collected once the checkpoints
|
||||
#have been used
|
||||
if SYNCHRONIZE:
|
||||
torch.cuda.synchronize()
|
||||
if PROFILE_TIME:
|
||||
timers('backward').start()
|
||||
|
||||
if CONTIGUOUS_CHECKPOINTING:
|
||||
global data_offsets, size_offsets
|
||||
global contiguous_data_buffers, contiguous_size_buffers
|
||||
|
||||
for buffers in contiguous_data_buffers:
|
||||
buffers = []
|
||||
|
||||
#frees up all the pointers to the checkpoints except for the ones
|
||||
#stored by save for backward
|
||||
contiguous_data_buffers = []
|
||||
contiguous_size_buffers = []
|
||||
data_offsets = []
|
||||
size_offsets = []
|
||||
|
||||
#see_memory_usage("In backward checkpointing code", force=True)
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError("Checkpointing is not compatible with .grad(), "
|
||||
"please use .backward() if possible")
|
||||
|
||||
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
|
||||
|
||||
if PARTITION_ACTIVATIONS:
|
||||
#with torch.cuda.stream(transport_stream):
|
||||
inputs = get_full_inputs(ctx.saved_tensors,
|
||||
device=cuda_device if PA_TO_CPU else None)
|
||||
detached_inputs = detach_variable(inputs)
|
||||
else:
|
||||
inputs = ctx.saved_tensors
|
||||
detached_inputs = detach_variable(inputs)
|
||||
|
||||
# Store the current states.
|
||||
bwd_cpu_rng_state = torch.get_rng_state()
|
||||
bwd_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
|
||||
|
||||
# Set the states to what it used to be before the forward pass.
|
||||
torch.set_rng_state(ctx.fwd_cpu_rng_state)
|
||||
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
|
||||
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
|
||||
|
||||
# if PARTITION_ACTIVATIONS:
|
||||
# current_stream=torch.cuda.current_stream()
|
||||
# current_stream.wait_stream(transport_stream)
|
||||
|
||||
with torch.enable_grad():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
# Set the states back to what it was at the start of this function.
|
||||
torch.set_rng_state(bwd_cpu_rng_state)
|
||||
_set_cuda_rng_state(bwd_cuda_rng_state)
|
||||
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs, )
|
||||
torch.autograd.backward(outputs, args)
|
||||
|
||||
if PROFILE_TIME:
|
||||
timers('backward').stop()
|
||||
timers.log(['backward'])
|
||||
if SYNCHRONIZE:
|
||||
torch.cuda.synchronize()
|
||||
return (None, ) + tuple(inp.grad for inp in detached_inputs)
|
||||
|
||||
|
||||
def checkpoint(function, *args):
|
||||
"""Checkpoint a model or part of the model.
|
||||
This has been directly copied from torch.utils.checkpoint. """
|
||||
return CheckpointFunction.apply(function, *args)
|
||||
|
||||
|
||||
def partition_activations_in_checkpoint(partition_activation):
|
||||
global PARTITION_ACTIVATIONS
|
||||
PARTITION_ACTIVATIONS = partition_activation
|
||||
if dist.get_rank() == 0:
|
||||
print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
|
||||
|
||||
|
||||
def set_num_layers(nlayers):
|
||||
global num_layers
|
||||
num_layers = nlayers
|
||||
|
||||
|
||||
def reset():
|
||||
"""Resets memory buffers related to contiguous memory optimizations.
|
||||
Should be called during eval when multiple forward propagations are
|
||||
computed without any backward propagation that usually clears these
|
||||
buffers.
|
||||
Arguments:
|
||||
None
|
||||
|
||||
Return:
|
||||
None
|
||||
"""
|
||||
if CONTIGUOUS_CHECKPOINTING:
|
||||
global data_offsets, size_offsets
|
||||
global contiguous_data_buffers, contiguous_size_buffers
|
||||
|
||||
for buffers in contiguous_data_buffers:
|
||||
buffers = []
|
||||
|
||||
#frees up all the pointers to the checkpoints except for the ones
|
||||
#stored by save for backward
|
||||
contiguous_data_buffers = []
|
||||
contiguous_size_buffers = []
|
||||
data_offsets = []
|
||||
size_offsets = []
|
||||
|
||||
|
||||
def _configure_using_config_file(deepspeed_config):
|
||||
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
|
||||
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
|
||||
|
||||
config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config
|
||||
print(config.repr())
|
||||
PARTITION_ACTIVATIONS = config.partition_activations
|
||||
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
|
||||
num_layers = config.number_checkpoints
|
||||
PA_TO_CPU = config.cpu_checkpointing
|
||||
SYNCHRONIZE = config.synchronize_checkpoint_boundary
|
||||
PROFILE_TIME = config.profile
|
||||
|
||||
|
||||
def _configure_defaults():
|
||||
|
||||
global mpu, num_layers, deepspeed_checkpointing_enabled
|
||||
|
||||
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
|
||||
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
|
||||
|
||||
PARTITION_ACTIVATIONS = False
|
||||
CONTIGUOUS_CHECKPOINTING = False
|
||||
num_layers = False
|
||||
PA_TO_CPU = False
|
||||
SYNCHRONIZE = False
|
||||
PROFILE_TIME = False
|
||||
deepspeed_checkpointing_enabled = True
|
||||
|
||||
|
||||
def configure(
|
||||
mpu_,
|
||||
deepspeed_config=None,
|
||||
partition_activations=None,
|
||||
contiguous_checkpointing=None,
|
||||
num_checkpoints=None,
|
||||
checkpoint_in_cpu=None,
|
||||
synchronize=None,
|
||||
profile=None,
|
||||
):
|
||||
"""Configure DeepSpeed Activation Checkpointing.
|
||||
|
||||
Arguments:
|
||||
mpu_: Optional: An object that implements the following methods
|
||||
get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size
|
||||
|
||||
deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to
|
||||
configure DeepSpeed Activation Checkpointing
|
||||
|
||||
partition_activations: Optional: Partitions activation checkpoint across model parallel
|
||||
GPUs when enabled. By default False. Will overwrite deepspeed_config if provided
|
||||
|
||||
contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
|
||||
buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
|
||||
Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
|
||||
provided
|
||||
|
||||
num_checkpoints: Optional: Number of activation checkpoints stored during the forward
|
||||
propagation of the model. Used to calculate the buffer size for contiguous_checkpointing
|
||||
Will overwrite deepspeed_config if provided
|
||||
|
||||
checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
|
||||
partition_activation. Default is false. Will overwrite deepspeed_config if provided
|
||||
|
||||
synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of
|
||||
each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
|
||||
By default false. Will overwrite deepspeed_config if provided
|
||||
|
||||
profile: Optional: Logs the forward and backward time for each
|
||||
deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
|
||||
if provided
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
global mpu, num_layers, deepspeed_checkpointing_enabled
|
||||
|
||||
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
|
||||
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
|
||||
|
||||
_configure_defaults()
|
||||
|
||||
if deepspeed_config is not None:
|
||||
_configure_using_config_file(deepspeed_config)
|
||||
|
||||
if mpu_ is not None:
|
||||
mpu = mpu_
|
||||
|
||||
if partition_activations is not None:
|
||||
PARTITION_ACTIVATIONS = partition_activations
|
||||
|
||||
if contiguous_checkpointing is not None:
|
||||
CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing
|
||||
|
||||
if num_checkpoints is not None:
|
||||
num_layers = num_checkpoints
|
||||
|
||||
if checkpoint_in_cpu is not None:
|
||||
PA_TO_CPU = checkpoint_in_cpu
|
||||
|
||||
if synchronize is not None:
|
||||
SYNCHRONIZE = synchronize
|
||||
|
||||
if profile is not None:
|
||||
PROFILE_TIME = profile
|
||||
|
||||
if PA_TO_CPU or CONTIGUOUS_CHECKPOINTING:
|
||||
assert PARTITION_ACTIVATIONS, "CPU Checkpointing/Contiguous Checkpointing is only availble with partitioned activations. Set partitioned activations to true in deepspeed config"
|
||||
if CONTIGUOUS_CHECKPOINTING:
|
||||
assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
|
||||
|
||||
|
||||
def is_configured():
|
||||
"""True if deepspeed activation checkpointing has been configured
|
||||
by calling deepspeed.checkpointing.configure, else returns false
|
||||
|
||||
Arguments:
|
||||
None
|
||||
|
||||
Return:
|
||||
True of configured, else False
|
||||
"""
|
||||
global deepspeed_checkpointing_enabled
|
||||
return deepspeed_checkpointing_enabled
|
|
@ -0,0 +1,110 @@
|
|||
"""
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT license.
|
||||
"""
|
||||
|
||||
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
|
||||
|
||||
#########################################
|
||||
# DeepSpeed Activation Checkpointing
|
||||
#########################################
|
||||
# Activation Checkpointing Allows to save memory by only keeping a select few
|
||||
#activations for the backpropagation.
|
||||
ACTIVATION_CHKPT_FORMAT = '''
|
||||
Activation Checkpointing should be configured as:
|
||||
"session_params": {
|
||||
"activation_checkpointing": {
|
||||
"partitioned_activations": [true|false],
|
||||
"number_checkpoints": 100,
|
||||
"contiguous_memory_optimization": [true|false],
|
||||
"cpu_checkpointing": [true|false]
|
||||
"profile": [true|false],
|
||||
"synchronize_checkpoint_boundary": [true|false],
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations'
|
||||
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False
|
||||
|
||||
ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints'
|
||||
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None
|
||||
|
||||
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization'
|
||||
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False
|
||||
|
||||
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary'
|
||||
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False
|
||||
|
||||
ACT_CHKPT_PROFILE = 'profile'
|
||||
ACT_CHKPT_PROFILE_DEFAULT = False
|
||||
|
||||
ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing'
|
||||
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False
|
||||
|
||||
ACT_CHKPT = 'activation_checkpointing'
|
||||
|
||||
ACT_CHKPT_DEFAULT = {
|
||||
ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
|
||||
ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
|
||||
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION:
|
||||
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
|
||||
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY:
|
||||
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
|
||||
ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT,
|
||||
ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
|
||||
}
|
||||
|
||||
|
||||
class DeepSpeedActivationCheckpointingConfig(object):
|
||||
def __init__(self, param_dict):
|
||||
super(DeepSpeedActivationCheckpointingConfig, self).__init__()
|
||||
|
||||
self.partition_activations = None
|
||||
self.contiguous_memory_optimization = None
|
||||
self.cpu_checkpointing = None
|
||||
self.number_checkpoints = None
|
||||
self.synchronize_checkpoint_boundary = None
|
||||
self.profile = None
|
||||
|
||||
if ACT_CHKPT in param_dict.keys():
|
||||
act_chkpt_config_dict = param_dict[ACT_CHKPT]
|
||||
else:
|
||||
act_chkpt_config_dict = ACT_CHKPT_DEFAULT
|
||||
|
||||
self._initialize(act_chkpt_config_dict)
|
||||
|
||||
"""
|
||||
For json serialization
|
||||
"""
|
||||
|
||||
def repr(self):
|
||||
return self.__dict__
|
||||
|
||||
def _initialize(self, act_chkpt_config_dict):
|
||||
self.partition_activations = get_scalar_param(
|
||||
act_chkpt_config_dict,
|
||||
ACT_CHKPT_PARTITION_ACTIVATIONS,
|
||||
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)
|
||||
|
||||
self.contiguous_memory_optimization = get_scalar_param(
|
||||
act_chkpt_config_dict,
|
||||
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
|
||||
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)
|
||||
|
||||
self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict,
|
||||
ACT_CHKPT_CPU_CHECKPOINTING,
|
||||
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)
|
||||
|
||||
self.number_checkpoints = get_scalar_param(act_chkpt_config_dict,
|
||||
ACT_CHKPT_NUMBER_CHECKPOINTS,
|
||||
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)
|
||||
|
||||
self.profile = get_scalar_param(act_chkpt_config_dict,
|
||||
ACT_CHKPT_PROFILE,
|
||||
ACT_CHKPT_PROFILE_DEFAULT)
|
||||
|
||||
self.synchronize_checkpoint_boundary = get_scalar_param(
|
||||
act_chkpt_config_dict,
|
||||
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
|
||||
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)
|
|
@ -8,6 +8,9 @@ import logging
|
|||
import json
|
||||
from deepspeed.pt.deepspeed_constants import *
|
||||
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
|
||||
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
|
||||
from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig
|
||||
from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig
|
||||
|
||||
TENSOR_CORE_ALIGN_SIZE = 8
|
||||
ADAM_OPTIMIZER = 'adam'
|
||||
|
@ -15,13 +18,6 @@ LAMB_OPTIMIZER = 'lamb'
|
|||
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
|
||||
|
||||
|
||||
def get_scalar_param(param_dict, param_name, param_default_value):
|
||||
if param_name in param_dict.keys():
|
||||
return param_dict[param_name]
|
||||
else:
|
||||
return param_default_value
|
||||
|
||||
|
||||
def get_fp16_enabled(param_dict):
|
||||
if FP16 in param_dict.keys():
|
||||
return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
|
||||
|
@ -92,10 +88,20 @@ def get_sparse_gradients_enabled(param_dict):
|
|||
return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)
|
||||
|
||||
|
||||
def get_zero_enabled(param_dict):
|
||||
def get_zero_optimization(param_dict):
|
||||
return get_scalar_param(param_dict, ZERO_OPTIMIZATION, ZERO_OPTIMIZATION_DEFAULT)
|
||||
|
||||
|
||||
def get_zero_reduce_scatter(param_dict):
|
||||
return get_scalar_param(param_dict, ZERO_REDUCE_SCATTER, ZERO_REDUCE_SCATTER_DEFAULT)
|
||||
|
||||
|
||||
def get_zero_max_elements_per_comm(param_dict):
|
||||
return get_scalar_param(param_dict,
|
||||
ZERO_MAX_ELEMENTS_PER_COMM,
|
||||
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT)
|
||||
|
||||
|
||||
def get_allgather_size(param_dict):
|
||||
return get_scalar_param(param_dict,
|
||||
ALLGATHER_SIZE,
|
||||
|
@ -204,6 +210,10 @@ def get_wall_clock_breakdown(param_dict):
|
|||
WALL_CLOCK_BREAKDOWN_DEFAULT)
|
||||
|
||||
|
||||
def get_memory_breakdown(param_dict):
|
||||
return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)
|
||||
|
||||
|
||||
def get_tensorboard_enabled(param_dict):
|
||||
if TENSORBOARD in param_dict.keys():
|
||||
return get_scalar_param(param_dict[TENSORBOARD],
|
||||
|
@ -231,10 +241,39 @@ def get_tensorboard_job_name(param_dict):
|
|||
return TENSORBOARD_JOB_NAME_DEFAULT
|
||||
|
||||
|
||||
'''Write deepspeed config files by modifying basic templates.
|
||||
Can be used for quicly changing parameters via command line parameters.'''
|
||||
|
||||
|
||||
class DeepSpeedConfigWriter:
|
||||
def __init__(self, data=None):
|
||||
self.data = data if data is not None else {}
|
||||
|
||||
def add_config(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def load_config(self, filename):
|
||||
self.data = json.load(open(filename,
|
||||
'r'),
|
||||
object_pairs_hook=dict_raise_error_on_duplicate_keys)
|
||||
|
||||
def write_config(self, filename):
|
||||
with open(filename, 'w') as outfile:
|
||||
json.dump(self.data, outfile)
|
||||
|
||||
|
||||
class DeepSpeedConfig(object):
|
||||
def __init__(self, json_file, mpu=None):
|
||||
def __init__(self, json_file, mpu=None, param_dict=None):
|
||||
super(DeepSpeedConfig, self).__init__()
|
||||
self._param_dict = json.load(open(json_file, 'r'))
|
||||
|
||||
if param_dict is None:
|
||||
self._param_dict = json.load(
|
||||
open(json_file,
|
||||
'r'),
|
||||
object_pairs_hook=dict_raise_error_on_duplicate_keys)
|
||||
else:
|
||||
self._param_dict = param_dict
|
||||
|
||||
try:
|
||||
self.global_rank = torch.distributed.get_rank()
|
||||
if mpu is None:
|
||||
|
@ -263,7 +302,14 @@ class DeepSpeedConfig(object):
|
|||
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
|
||||
|
||||
self.allgather_size = get_allgather_size(param_dict)
|
||||
self.zero_enabled = get_zero_enabled(param_dict)
|
||||
|
||||
self.zero_config = DeepSpeedZeroConfig(param_dict)
|
||||
self.zero_optimization_stage = self.zero_config.stage
|
||||
self.zero_enabled = self.zero_optimization_stage > 0
|
||||
|
||||
self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(
|
||||
param_dict)
|
||||
|
||||
self.gradient_clipping = get_gradient_clipping(param_dict)
|
||||
self.fp16_enabled = get_fp16_enabled(param_dict)
|
||||
self.loss_scale = get_loss_scale(param_dict)
|
||||
|
@ -285,6 +331,7 @@ class DeepSpeedConfig(object):
|
|||
self.scheduler_params = get_scheduler_params(param_dict)
|
||||
|
||||
self.wall_clock_breakdown = get_wall_clock_breakdown(param_dict)
|
||||
self.memory_breakdown = get_memory_breakdown(param_dict)
|
||||
self.tensorboard_enabled = get_tensorboard_enabled(param_dict)
|
||||
self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
|
||||
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
|
||||
|
@ -305,8 +352,8 @@ class DeepSpeedConfig(object):
|
|||
f'Gradient accumulation steps: {grad_acc} has to be greater than 0'
|
||||
|
||||
assert train_batch == micro_batch * grad_acc * self.world_size, \
|
||||
(f'Check batch related parameters. Train_batch_size is not equal'
|
||||
'to micro_batch_per_gpu * gradient_acc_step * world_size'
|
||||
(f'Check batch related parameters. train_batch_size is not equal'
|
||||
' to micro_batch_per_gpu * gradient_acc_step * world_size'
|
||||
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
|
||||
|
||||
def _set_batch_related_parameters(self):
|
||||
|
@ -387,6 +434,7 @@ class DeepSpeedConfig(object):
|
|||
def _do_error_check(self):
|
||||
if self.zero_enabled:
|
||||
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
|
||||
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
|
||||
|
||||
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
|
||||
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
"""
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT license.
|
||||
"""
|
||||
"""
|
||||
Collection of DeepSpeed configuration utilities
|
||||
"""
|
||||
|
||||
|
||||
def get_scalar_param(param_dict, param_name, param_default_value):
|
||||
if param_name in param_dict.keys():
|
||||
return param_dict[param_name]
|
||||
else:
|
||||
return param_default_value
|
||||
|
||||
|
||||
def dict_raise_error_on_duplicate_keys(ordered_pairs):
|
||||
"""Reject duplicate keys."""
|
||||
d = {}
|
||||
for k, v in ordered_pairs:
|
||||
if k in d:
|
||||
raise ValueError("Duplicate key in DeepSpeed config: %r" % (k, ))
|
||||
else:
|
||||
d[k] = v
|
||||
return d
|
|
@ -15,7 +15,7 @@ ROUTE_ENCODE = "encode"
|
|||
# Batch size
|
||||
#############################################
|
||||
TRAIN_BATCH_SIZE = "train_batch_size"
|
||||
TRAIN_BATCH_SIZE_DEFAULT = 1
|
||||
TRAIN_BATCH_SIZE_DEFAULT = None
|
||||
|
||||
#############################################
|
||||
# Optimizer and lr scheduler
|
||||
|
@ -133,14 +133,27 @@ GRADIENT_CLIPPING_DEFAULT = 0.
|
|||
# ZeRO optimization
|
||||
#########################################
|
||||
# ZeRO optimization. By default, this optimization is not enabled.
|
||||
# Users can configure in ds_config.json as below example:
|
||||
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
|
||||
ZERO_FORMAT = '''
|
||||
ZeRO optimization should be enabled as:
|
||||
"zero_optimization": true,
|
||||
"zero_all_gather_size": 200
|
||||
"session_params": {
|
||||
"zero_optimization": [0|1|2],
|
||||
"zero_all_gather_size": 200
|
||||
}
|
||||
'''
|
||||
|
||||
ZERO_OPTIMIZATION = 'zero_optimization'
|
||||
ZERO_OPTIMIZATION_DEFAULT = False
|
||||
ZERO_OPTIMIZATION_DEFAULT = 0
|
||||
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
|
||||
ZERO_OPTIMIZATION_GRADIENTS = 2
|
||||
ZERO_OPTIMIZATION_WEIGHTS = 3
|
||||
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
|
||||
|
||||
ZERO_REDUCE_SCATTER = "zero_reduce_scatter"
|
||||
ZERO_REDUCE_SCATTER_DEFAULT = True
|
||||
|
||||
ZERO_MAX_ELEMENTS_PER_COMM = "zero_max_elements_per_comm"
|
||||
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT = 5e8
|
||||
|
||||
ALLGATHER_SIZE = 'allgather_size'
|
||||
ALLGATHER_SIZE_DEFAULT = 500000000
|
||||
|
@ -217,6 +230,9 @@ Wall block breakdown should be enabled as:
|
|||
WALL_CLOCK_BREAKDOWN = 'wall_clock_breakdown'
|
||||
WALL_CLOCK_BREAKDOWN_DEFAULT = False
|
||||
|
||||
MEMORY_BREAKDOWN = 'memory_breakdown'
|
||||
MEMORY_BREAKDOWN_DEFAULT = False
|
||||
|
||||
#########################################
|
||||
# Tensorboard
|
||||
#########################################
|
||||
|
|
|
@ -8,11 +8,14 @@ import os
|
|||
import warnings
|
||||
import torch.distributed as dist
|
||||
from torch.nn.modules import Module
|
||||
from torch.distributed.distributed_c10d import _get_global_rank
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
|
||||
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
|
||||
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
|
||||
import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing
|
||||
|
||||
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
|
||||
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
|
||||
|
@ -21,8 +24,10 @@ from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
|
|||
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
|
||||
|
||||
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
|
||||
from deepspeed.pt.deepspeed_constants import ROUTE_TRAIN, ROUTE_PREDICT, \
|
||||
ROUTE_EVAL, TORCH_DISTRIBUTED_DEFAULT_PORT
|
||||
from deepspeed.pt.deepspeed_constants import \
|
||||
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
|
||||
TORCH_DISTRIBUTED_DEFAULT_PORT, \
|
||||
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
|
||||
|
||||
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
|
||||
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
|
||||
|
@ -96,7 +101,8 @@ class DeepSpeedLight(Module):
|
|||
lr_scheduler=None,
|
||||
mpu=None,
|
||||
dist_init_required=None,
|
||||
collate_fn=None):
|
||||
collate_fn=None,
|
||||
config_params=None):
|
||||
super(DeepSpeedLight, self).__init__()
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
|
@ -116,6 +122,7 @@ class DeepSpeedLight(Module):
|
|||
self.gradient_predivide_factor = 1.0
|
||||
self.gradient_average = True
|
||||
self.warn_unscaled_loss = True
|
||||
self.config_params = config_params
|
||||
|
||||
if dist_init_required is None:
|
||||
dist_init_required = not dist.is_initialized()
|
||||
|
@ -146,6 +153,9 @@ class DeepSpeedLight(Module):
|
|||
# Configure distributed model
|
||||
self._configure_distributed_model(model)
|
||||
|
||||
# Configure wall clock timer
|
||||
self.timers = SynchronizedWallClockTimer()
|
||||
|
||||
# Throughput timer
|
||||
self.tput_timer = ThroughputTimer(
|
||||
batch_size=self.train_micro_batch_size_per_gpu(),
|
||||
|
@ -163,9 +173,6 @@ class DeepSpeedLight(Module):
|
|||
self._configure_lr_scheduler(lr_scheduler)
|
||||
self._report_progress(0)
|
||||
|
||||
# Configure wall clock timer
|
||||
self.timers = SynchronizedWallClockTimer()
|
||||
|
||||
# Bookkeeping for csr support
|
||||
self.csr_tensor_module_names = set()
|
||||
if self.sparse_gradients_enabled():
|
||||
|
@ -245,6 +252,9 @@ class DeepSpeedLight(Module):
|
|||
def wall_clock_breakdown(self):
|
||||
return self._config.wall_clock_breakdown
|
||||
|
||||
def memory_breakdown(self):
|
||||
return self._config.memory_breakdown
|
||||
|
||||
def sparse_gradients_enabled(self):
|
||||
return self._config.sparse_gradients_enabled
|
||||
|
||||
|
@ -275,6 +285,30 @@ class DeepSpeedLight(Module):
|
|||
def zero_allow_untested_optimizer(self):
|
||||
return self._config.zero_allow_untested_optimizer
|
||||
|
||||
def zero_reduce_scatter(self):
|
||||
return self._config.zero_config.reduce_scatter
|
||||
|
||||
def zero_overlap_comm(self):
|
||||
return self._config.zero_config.overlap_comm
|
||||
|
||||
def zero_max_elements_per_comm(self):
|
||||
return self._config.zero_max_elements_per_comm
|
||||
|
||||
def zero_optimization_stage(self):
|
||||
return self._config.zero_optimization_stage
|
||||
|
||||
def zero_reduce_bucket_size(self):
|
||||
return self._config.zero_config.reduce_bucket_size
|
||||
|
||||
def zero_allgather_bucket_size(self):
|
||||
return self._config.zero_config.allgather_bucket_size
|
||||
|
||||
def zero_optimization_partition_gradients(self):
|
||||
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS
|
||||
|
||||
def zero_contiguous_gradients(self):
|
||||
return self._config.zero_config.contiguous_gradients
|
||||
|
||||
def allgather_size(self):
|
||||
return self._config.allgather_size
|
||||
|
||||
|
@ -296,8 +330,8 @@ class DeepSpeedLight(Module):
|
|||
def steps_per_print(self):
|
||||
return self._config.steps_per_print
|
||||
|
||||
def disable_allgather(self):
|
||||
return self._config.disable_allgather
|
||||
def zero_allgather_partitions(self):
|
||||
return self._config.zero_config.allgather_partitions
|
||||
|
||||
def dump_state(self):
|
||||
return self._config.dump_state
|
||||
|
@ -375,7 +409,9 @@ class DeepSpeedLight(Module):
|
|||
# Configure based on command line arguments
|
||||
def _configure_with_arguments(self, args, mpu):
|
||||
self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
|
||||
self._config = DeepSpeedConfig(args.deepspeed_config, mpu)
|
||||
self._config = DeepSpeedConfig(args.deepspeed_config,
|
||||
mpu,
|
||||
param_dict=self.config_params)
|
||||
|
||||
# Validate command line arguments
|
||||
def _do_args_sanity_check(self, args):
|
||||
|
@ -390,11 +426,12 @@ class DeepSpeedLight(Module):
|
|||
assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
|
||||
'DeepSpeed requires integer command line parameter --local_rank'
|
||||
|
||||
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
|
||||
'DeepSpeed requires --deepspeed_config to specify configuration file'
|
||||
if self.config_params is None:
|
||||
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
|
||||
'DeepSpeed requires --deepspeed_config to specify configuration file'
|
||||
|
||||
assert os.path.isfile(args.deepspeed_config), \
|
||||
'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)
|
||||
assert os.path.isfile(args.deepspeed_config), \
|
||||
'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)
|
||||
|
||||
def _is_supported_optimizer(self, optimizer_name):
|
||||
return optimizer_name in DEEPSPEED_OPTIMIZERS or \
|
||||
|
@ -424,7 +461,8 @@ class DeepSpeedLight(Module):
|
|||
else:
|
||||
self.data_parallel_group = self.mpu.get_data_parallel_group()
|
||||
self.dp_world_size = self.mpu.get_data_parallel_world_size()
|
||||
src_rank = self.mpu.get_model_parallel_rank()
|
||||
src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
|
||||
print(f"global src_rank={src_rank}")
|
||||
for p in self.module.parameters():
|
||||
if torch.is_tensor(p):
|
||||
dist.broadcast(p, src_rank, group=self.data_parallel_group)
|
||||
|
@ -518,17 +556,42 @@ class DeepSpeedLight(Module):
|
|||
return optimizer
|
||||
|
||||
def _configure_zero_optimizer(self, optimizer):
|
||||
logging.info('Creating fp16 zero optimizer')
|
||||
optimizer = FP16_DeepSpeedZeroOptimizer(
|
||||
optimizer,
|
||||
static_loss_scale=self.loss_scale(),
|
||||
dynamic_loss_scale=self.dynamic_loss_scale(),
|
||||
dynamic_loss_args=self.dynamic_loss_scale_args(),
|
||||
dp_process_group=self.data_parallel_group,
|
||||
clip_grad=self.gradient_clipping(),
|
||||
all_gather_partitions=not self.disable_allgather(),
|
||||
allgather_size=self.allgather_size(),
|
||||
mpu=self.mpu)
|
||||
zero_stage = self.zero_optimization_stage()
|
||||
logging.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
|
||||
|
||||
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
|
||||
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
|
||||
logging.info('Creating fp16 ZeRO Optimizer Stage 1')
|
||||
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
|
||||
optimizer,
|
||||
static_loss_scale=self.loss_scale(),
|
||||
dynamic_loss_scale=self.dynamic_loss_scale(),
|
||||
dynamic_loss_args=self.dynamic_loss_scale_args(),
|
||||
clip_grad=self.gradient_clipping(),
|
||||
all_gather_partitions=self.zero_allgather_partitions(),
|
||||
allgather_size=self.zero_allgather_bucket_size(),
|
||||
max_elements_per_comm=self.zero_reduce_bucket_size(),
|
||||
dp_process_group=self.data_parallel_group,
|
||||
mpu=self.mpu)
|
||||
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
|
||||
assert self.gradient_accumulation_steps() == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1"
|
||||
optimizer = FP16_DeepSpeedZeroOptimizer(
|
||||
optimizer,
|
||||
timers=self.timers,
|
||||
static_loss_scale=self.loss_scale(),
|
||||
dynamic_loss_scale=self.dynamic_loss_scale(),
|
||||
dynamic_loss_args=self.dynamic_loss_scale_args(),
|
||||
clip_grad=self.gradient_clipping(),
|
||||
contiguous_gradients=self.zero_contiguous_gradients(),
|
||||
reduce_bucket_size=self.zero_reduce_bucket_size(),
|
||||
allgather_bucket_size=self.zero_allgather_bucket_size(),
|
||||
dp_process_group=self.data_parallel_group,
|
||||
reduce_scatter=self.zero_reduce_scatter(),
|
||||
overlap_comm=self.zero_overlap_comm(),
|
||||
mpu=self.mpu)
|
||||
else:
|
||||
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
|
||||
logging.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
|
||||
|
||||
return optimizer
|
||||
|
||||
|
@ -624,7 +687,16 @@ class DeepSpeedLight(Module):
|
|||
|
||||
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
|
||||
if self.is_gradient_accumulation_boundary():
|
||||
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
|
||||
if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
|
||||
assert self.zero_reduce_scatter()
|
||||
self.optimizer.reduce_scatter_gradients(
|
||||
postscale_gradients=self.postscale_gradients(),
|
||||
gradient_predivide_factor=self.gradient_predivide_factor,
|
||||
gradient_average=self.gradient_average)
|
||||
elif self.zero_optimization_partition_gradients():
|
||||
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
|
||||
else:
|
||||
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
|
||||
|
||||
def backward(self, loss, allreduce_gradients=True):
|
||||
r"""Execute backward pass on the loss
|
||||
|
@ -636,7 +708,7 @@ class DeepSpeedLight(Module):
|
|||
|
||||
# scale loss w.r.t. gradient accumulation if needed
|
||||
if self.gradient_accumulation_steps() > 1:
|
||||
loss = self._scale_loss(loss)
|
||||
loss = self._scale_loss(loss.float())
|
||||
|
||||
# Log training Loss
|
||||
if self.tensorboard_enabled():
|
||||
|
@ -765,27 +837,28 @@ class DeepSpeedLight(Module):
|
|||
'backward_inner_microstep',
|
||||
'backward_allreduce_microstep',
|
||||
'step_microstep'
|
||||
])
|
||||
# Log timing
|
||||
if self.tensorboard_enabled():
|
||||
if self.is_gradient_accumulation_boundary():
|
||||
if self.global_rank == 0:
|
||||
self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/Samples/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/Samples/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
|
||||
]
|
||||
for event in self.summary_events: # write_summary_events
|
||||
self.summary_writer.add_scalar(event[0], event[1], event[2])
|
||||
self.summary_writer.flush()
|
||||
self.timers.log([
|
||||
'forward',
|
||||
'backward',
|
||||
'backward_inner',
|
||||
'backward_allreduce',
|
||||
'step'
|
||||
])
|
||||
],
|
||||
memory_breakdown=self.memory_breakdown())
|
||||
|
||||
if self.is_gradient_accumulation_boundary():
|
||||
if self.tensorboard_enabled() and torch.distributed.get_rank(
|
||||
) == 0: # this is done before the log because log resets timers
|
||||
self.summary_events = [(f'Train/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
|
||||
(f'Train/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
|
||||
]
|
||||
for event in self.summary_events: # write_summary_events
|
||||
self.summary_writer.add_scalar(event[0], event[1], event[2])
|
||||
self.summary_writer.flush()
|
||||
self.timers.log([
|
||||
'forward',
|
||||
'backward',
|
||||
'backward_inner',
|
||||
'backward_allreduce',
|
||||
'step'
|
||||
])
|
||||
|
||||
self.micro_steps += 1
|
||||
|
||||
|
@ -971,19 +1044,30 @@ class DeepSpeedLight(Module):
|
|||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
def load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
|
||||
def load_checkpoint(self,
|
||||
load_dir,
|
||||
tag,
|
||||
load_module_strict=True,
|
||||
load_optimizer_states=True,
|
||||
load_lr_scheduler_states=True):
|
||||
r"""Load training checkpoint
|
||||
|
||||
Arguments:
|
||||
load_dir: Required. Directory to load the checkpoint from
|
||||
tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
|
||||
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
|
||||
load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
|
||||
load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
|
||||
Return:
|
||||
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
|
||||
client_state: State dictionary used for loading required training states in the client code.
|
||||
"""
|
||||
|
||||
load_path, client_states = self._load_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
|
||||
load_path, client_states = self._load_checkpoint(load_dir,
|
||||
tag,
|
||||
load_module_strict=load_module_strict,
|
||||
load_optimizer_states=load_optimizer_states,
|
||||
load_lr_scheduler_states=load_lr_scheduler_states)
|
||||
|
||||
if self.zero_optimization() and load_path is not None:
|
||||
self._load_zero_checkpoint(load_dir,
|
||||
|
@ -992,7 +1076,12 @@ class DeepSpeedLight(Module):
|
|||
|
||||
return load_path, client_states
|
||||
|
||||
def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
|
||||
def _load_checkpoint(self,
|
||||
load_dir,
|
||||
tag,
|
||||
load_module_strict=True,
|
||||
load_optimizer_states=True,
|
||||
load_lr_scheduler_states=True):
|
||||
|
||||
load_path = self._get_ckpt_name(load_dir, tag)
|
||||
|
||||
|
@ -1005,12 +1094,13 @@ class DeepSpeedLight(Module):
|
|||
logging.info('Loading checkpoint: {}'.format(load_path))
|
||||
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
|
||||
|
||||
self.load_module_state_dict(checkpoint['module'])
|
||||
self.load_module_state_dict(state_dict=checkpoint['module'],
|
||||
strict=load_module_strict)
|
||||
if not self.zero_optimization():
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'],
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
|
||||
if self.lr_scheduler is not None:
|
||||
if load_lr_scheduler_states and self.lr_scheduler is not None:
|
||||
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
||||
|
||||
self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
|
||||
|
@ -1019,6 +1109,7 @@ class DeepSpeedLight(Module):
|
|||
deepspeed_states = [
|
||||
'module',
|
||||
'optimizer',
|
||||
'lr_scheduler',
|
||||
'csr_tensor_module_names',
|
||||
'skipped_steps',
|
||||
'global_steps'
|
||||
|
@ -1058,19 +1149,15 @@ class DeepSpeedLight(Module):
|
|||
#There seems to be issue creating them in parallel
|
||||
self._create_checkpoint_files(save_dir, tag)
|
||||
|
||||
try:
|
||||
if self.save_non_zero_checkpoint:
|
||||
self._save_checkpoint(save_dir, tag, client_state=client_state)
|
||||
if self.save_non_zero_checkpoint:
|
||||
self._save_checkpoint(save_dir, tag, client_state=client_state)
|
||||
|
||||
if self.save_zero_checkpoint:
|
||||
self._save_zero_checkpoint(save_dir, tag)
|
||||
|
||||
if self.save_zero_checkpoint:
|
||||
self._save_zero_checkpoint(save_dir, tag)
|
||||
except:
|
||||
logging.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
|
||||
return False
|
||||
return True
|
||||
|
||||
def _create_checkpoint_files(self, save_dir, tag):
|
||||
|
||||
#checkpoint files are created sequentially
|
||||
for rank in range(self.world_size):
|
||||
if rank == self.global_rank:
|
||||
|
@ -1114,14 +1201,8 @@ class DeepSpeedLight(Module):
|
|||
torch.save(state, save_path)
|
||||
|
||||
def _save_zero_checkpoint(self, save_path, tag):
|
||||
try:
|
||||
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
|
||||
#self._ensure_directory_exists(zero_checkpoint_name)
|
||||
|
||||
except:
|
||||
logging.error(
|
||||
f'Failed Saving Zero model checkpoint to {save_path} with tag {tag}')
|
||||
|
||||
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
|
||||
#self._ensure_directory_exists(zero_checkpoint_name)
|
||||
zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
|
||||
torch.save(zero_sd, zero_checkpoint_name)
|
||||
logging.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
|
||||
|
|
|
@ -69,13 +69,27 @@ class SynchronizedWallClockTimer:
|
|||
self.timers[name] = self.Timer(name)
|
||||
return self.timers[name]
|
||||
|
||||
def log(self, names, normalizer=1.0, reset=True):
|
||||
@staticmethod
|
||||
def memory_usage():
|
||||
alloc = "mem_allocated: {:.4f} GB".format(torch.cuda.memory_allocated() /
|
||||
(1024 * 1024 * 1024))
|
||||
max_alloc = "max_mem_allocated: {:.4f} GB".format(
|
||||
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024))
|
||||
cache = "cache_allocated: {:.4f} GB".format(torch.cuda.memory_cached() /
|
||||
(1024 * 1024 * 1024))
|
||||
max_cache = "max_cache_allocated: {:.4f} GB".format(
|
||||
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))
|
||||
return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)
|
||||
|
||||
def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False):
|
||||
"""Log a group of timers."""
|
||||
assert normalizer > 0.0
|
||||
string = 'time (ms)'
|
||||
for name in names:
|
||||
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
|
||||
string += ' | {}: {:.2f}'.format(name, elapsed_time)
|
||||
if memory_breakdown:
|
||||
string += self.memory_usage()
|
||||
print_rank_0(string)
|
||||
|
||||
|
||||
|
|
|
@ -12,9 +12,10 @@ from torch._six import inf
|
|||
|
||||
class CheckOverflow(object):
|
||||
'''Checks for overflow in gradient across parallel process'''
|
||||
def __init__(self, param_groups=None, mpu=None):
|
||||
def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False):
|
||||
self.mpu = mpu
|
||||
self.params = [] if param_groups else None
|
||||
self.zero_reduce_scatter = zero_reduce_scatter
|
||||
if param_groups:
|
||||
for group in param_groups:
|
||||
for param in group:
|
||||
|
@ -54,8 +55,8 @@ class CheckOverflow(object):
|
|||
|
||||
# `params` is a list / generator of torch.Variable
|
||||
def has_overflow_serial(self, params):
|
||||
for p in params:
|
||||
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
|
||||
for i, p in enumerate(params):
|
||||
if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -67,7 +68,11 @@ class CheckOverflow(object):
|
|||
#torch.distributed.all_reduce(overflow_gpu,
|
||||
# op=torch.distributed.ReduceOp.MAX,
|
||||
# group=mpu.get_model_parallel_group())
|
||||
if self.mpu is not None:
|
||||
if self.zero_reduce_scatter:
|
||||
torch.distributed.all_reduce(overflow_gpu,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=torch.distributed.group.WORLD)
|
||||
elif self.mpu is not None:
|
||||
torch.distributed.all_reduce(overflow_gpu,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=self.mpu.get_model_parallel_group())
|
||||
|
@ -76,7 +81,7 @@ class CheckOverflow(object):
|
|||
|
||||
# `x` is a torch.Tensor
|
||||
@staticmethod
|
||||
def _has_inf_or_nan(x):
|
||||
def _has_inf_or_nan(x, i):
|
||||
try:
|
||||
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
|
||||
# Pytorch's .sum() creates a one-element tensor of the same type as x
|
||||
|
@ -93,10 +98,25 @@ class CheckOverflow(object):
|
|||
return True
|
||||
else:
|
||||
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
|
||||
_handle_overflow(cpu_sum, x, i)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _handle_overflow(cpu_sum, x, i):
|
||||
import math
|
||||
rank = torch.distributed.get_rank()
|
||||
if rank == 0:
|
||||
t_i = -1
|
||||
for v_i, v in enumerate(x.data.contiguous().view(-1)):
|
||||
if not math.isfinite(float(v)):
|
||||
t_i = v_i
|
||||
break
|
||||
print(
|
||||
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
|
||||
)
|
||||
|
||||
|
||||
def get_grad_norm(parameters, norm_type=2, mpu=None):
|
||||
"""Clips gradient norm of an iterable of parameters.
|
||||
|
||||
|
@ -221,3 +241,33 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
|
|||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
def is_model_parallel_parameter(p):
|
||||
return hasattr(p, 'model_parallel') and p.model_parallel
|
||||
|
||||
|
||||
def see_memory_usage(message):
|
||||
return
|
||||
if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
|
||||
return
|
||||
|
||||
# Print message except when distributed but not rank 0
|
||||
print(message, flush=True)
|
||||
print("Memory Allocated ",
|
||||
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
|
||||
"GigaBytes",
|
||||
flush=True)
|
||||
print("Max Memory Allocated ",
|
||||
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
|
||||
"GigaBytes",
|
||||
flush=True)
|
||||
print("Cache Allocated ",
|
||||
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
|
||||
"GigaBytes",
|
||||
flush=True)
|
||||
print("Max cache Allocated ",
|
||||
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
|
||||
"GigaBytes",
|
||||
flush=True)
|
||||
print(" ", flush=True)
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT license.
|
||||
"""
|
||||
|
||||
import logging
|
||||
#from deepspeed.pt.deepspeed_constants import *
|
||||
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
|
||||
|
||||
#########################################
|
||||
# ZeRO optimization
|
||||
#########################################
|
||||
# ZeRO optimization. By default, this optimization is not enabled.
|
||||
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
|
||||
ZERO_FORMAT = '''
|
||||
ZeRO optimization should be enabled as:
|
||||
"session_params": {
|
||||
"zero_optimization": {
|
||||
"stage": [0|1|2],
|
||||
"allgather_partitions": [true|false],
|
||||
"allgather_bucket_size": 500000000,
|
||||
"reduce_scatter": [true|false],
|
||||
"contiguous_gradients" : [true|false]
|
||||
"overlap_comm": [true|false],
|
||||
"reduce_bucket_size": 500000000
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
ZERO_OPTIMIZATION = 'zero_optimization'
|
||||
ZERO_OPTIMIZATION_DISABLED = 0
|
||||
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
|
||||
ZERO_OPTIMIZATION_GRADIENTS = 2
|
||||
ZERO_OPTIMIZATION_WEIGHTS = 3
|
||||
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
|
||||
|
||||
ZERO_OPTIMIZATION_STAGE = 'stage'
|
||||
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
|
||||
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
|
||||
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
|
||||
|
||||
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
|
||||
|
||||
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
|
||||
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
|
||||
|
||||
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
|
||||
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
|
||||
|
||||
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
|
||||
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
|
||||
|
||||
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
|
||||
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = True
|
||||
|
||||
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
|
||||
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
|
||||
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
|
||||
|
||||
ZERO_OPTIMIZATION_DEFAULT = {
|
||||
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
|
||||
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
|
||||
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
|
||||
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
|
||||
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
|
||||
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT
|
||||
}
|
||||
|
||||
|
||||
class DeepSpeedZeroConfig(object):
|
||||
def __init__(self, param_dict):
|
||||
super(DeepSpeedZeroConfig, self).__init__()
|
||||
|
||||
self.stage = None
|
||||
self.contiguous_gradients = None
|
||||
self.reduce_scatter = None
|
||||
self.reduce_bucket_size = None
|
||||
self.allgather_partitions = None
|
||||
self.allgather_bucket_size = None
|
||||
self.overlap_comm = None
|
||||
|
||||
if ZERO_OPTIMIZATION in param_dict.keys():
|
||||
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
|
||||
if type(zero_config_dict) is bool:
|
||||
zero_config_dict = self.read_zero_config_deprecated(param_dict)
|
||||
else:
|
||||
zero_config_dict = ZERO_OPTIMIZATION_DEFAULT
|
||||
|
||||
self._initialize(zero_config_dict)
|
||||
|
||||
def read_zero_config_deprecated(self, param_dict):
|
||||
zero_config_dict = {}
|
||||
zero_config_dict[
|
||||
ZERO_OPTIMIZATION_STAGE] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
|
||||
if zero_config_dict[ZERO_OPTIMIZATION_STAGE] > 0:
|
||||
zero_config_dict[ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE] = get_scalar_param(
|
||||
param_dict,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
|
||||
|
||||
logging.warning(
|
||||
'DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}'
|
||||
.format(ZERO_FORMAT))
|
||||
return zero_config_dict
|
||||
|
||||
"""
|
||||
For json serialization
|
||||
"""
|
||||
|
||||
def repr(self):
|
||||
return self.__dict__
|
||||
|
||||
def _initialize(self, zero_config_dict):
|
||||
self.stage = get_scalar_param(zero_config_dict,
|
||||
ZERO_OPTIMIZATION_STAGE,
|
||||
ZERO_OPTIMIZATION_STAGE_DEFAULT)
|
||||
|
||||
self.contiguous_gradients = get_scalar_param(
|
||||
zero_config_dict,
|
||||
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
|
||||
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
|
||||
|
||||
self.reduce_bucket_size = get_scalar_param(
|
||||
zero_config_dict,
|
||||
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
|
||||
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT)
|
||||
|
||||
self.reduce_scatter = get_scalar_param(zero_config_dict,
|
||||
ZERO_OPTIMIZATION_REDUCE_SCATTER,
|
||||
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
|
||||
|
||||
self.overlap_comm = get_scalar_param(zero_config_dict,
|
||||
ZERO_OPTIMIZATION_OVERLAP_COMM,
|
||||
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
|
||||
|
||||
self.allgather_partitions = get_scalar_param(
|
||||
zero_config_dict,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT)
|
||||
|
||||
self.allgather_bucket_size = get_scalar_param(
|
||||
zero_config_dict,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
|
||||
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,803 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from collections import defaultdict
|
||||
|
||||
from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups, \
|
||||
pprint
|
||||
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
|
||||
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
|
||||
|
||||
|
||||
def flatten_dense_tensors_sub_partition_aligned(tensor_list,
|
||||
dp,
|
||||
max_elements_per_comm,
|
||||
pg):
|
||||
num_elements = 0
|
||||
for tensor in tensor_list:
|
||||
num_elements = num_elements + tensor.numel()
|
||||
|
||||
pprint("Total number of elements in model: {}, max elements per com: {}".format(
|
||||
num_elements,
|
||||
max_elements_per_comm))
|
||||
|
||||
max_elements_per_comm = min(max_elements_per_comm, num_elements)
|
||||
sub_partition_size = int(max_elements_per_comm // dp)
|
||||
|
||||
alignment = sub_partition_size
|
||||
|
||||
# if alignment == 0:
|
||||
# # number of elements not divisible by dp, outside range and small model must pad with zeroes
|
||||
# pad_tensor = torch.zeros(max_elements_per_comm,
|
||||
# device=tensor_list[0].device,
|
||||
# dtype=tensor_list[0].dtype)
|
||||
# return _flatten_dense_tensors(pad_tensor)
|
||||
|
||||
remaining = int(num_elements % alignment)
|
||||
|
||||
# ensure we have equal sized sub-partitions
|
||||
elements_to_add = 0
|
||||
if remaining:
|
||||
elements_to_add = alignment - remaining
|
||||
# adding padded tensor later after we check comm alignment
|
||||
pprint("adding pad tensor for alignment, {} + {}->{}".format(
|
||||
num_elements,
|
||||
elements_to_add,
|
||||
num_elements + elements_to_add))
|
||||
#num_elements = num_elements + elements_to_add
|
||||
else:
|
||||
padded_tensor_list = tensor_list
|
||||
|
||||
num_partitions = int((num_elements + elements_to_add) // sub_partition_size)
|
||||
assert (num_elements + elements_to_add) % sub_partition_size == 0, "num elements should be " \
|
||||
"aligned by sub partition " \
|
||||
"size"
|
||||
num_comm_intervals = int(num_partitions // dp)
|
||||
partition_remaining = int(num_partitions % dp)
|
||||
pprint("num_comm_intervals={}, partition_remaining={}".format(
|
||||
num_comm_intervals,
|
||||
partition_remaining))
|
||||
if partition_remaining != 0:
|
||||
pprint("adding pad tensor and/or extra sub partition")
|
||||
# add pad tensor for alignment of comm interval, this overrules previous possibly sub-partition alignment
|
||||
num_comm_intervals += 1
|
||||
aligned_comm_elements = num_comm_intervals * sub_partition_size * dp
|
||||
elements_to_add = aligned_comm_elements - num_elements
|
||||
|
||||
pad_tensor = torch.zeros(elements_to_add,
|
||||
device=tensor_list[0].device,
|
||||
dtype=tensor_list[0].dtype)
|
||||
padded_tensor_list = tensor_list + [pad_tensor]
|
||||
pprint("adding pad tensor and/or extra sub partition, {} + {}->{}".format(
|
||||
num_elements,
|
||||
elements_to_add,
|
||||
num_elements + elements_to_add))
|
||||
num_elements += elements_to_add
|
||||
elif elements_to_add > 0:
|
||||
# add pad tensor for just alignment of sub-partition
|
||||
pad_tensor = torch.zeros(elements_to_add,
|
||||
device=tensor_list[0].device,
|
||||
dtype=tensor_list[0].dtype)
|
||||
padded_tensor_list = tensor_list + [pad_tensor]
|
||||
num_elements += elements_to_add
|
||||
|
||||
if pg is None or dist.get_rank(group=pg) == 0:
|
||||
print("Number of Elements (w. padding) is ", num_elements)
|
||||
|
||||
padded_num_elems = 0
|
||||
for p in padded_tensor_list:
|
||||
padded_num_elems += p.numel()
|
||||
assert num_elements == padded_num_elems, "{} != {}, rank={}".format(num_elements, padded_num_elems, dist.get_rank())
|
||||
|
||||
return _flatten_dense_tensors(padded_tensor_list)
|
||||
|
||||
|
||||
def _single_range_check(current_index, start_index, end_index, tensor_size):
|
||||
offset = 0
|
||||
if (current_index >= start_index) and (current_index < end_index):
|
||||
# Fully inside bounds
|
||||
return True, offset
|
||||
elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
|
||||
# Partially contained, compute offset
|
||||
offset = start_index - current_index
|
||||
return True, offset
|
||||
else:
|
||||
return False, offset
|
||||
|
||||
|
||||
def _range_check(current_index, element_intervals, tensor_size):
|
||||
results = []
|
||||
for comm_idx, interval in enumerate(element_intervals):
|
||||
start_index, end_index = interval
|
||||
contained, offset = _single_range_check(current_index, start_index, end_index, tensor_size)
|
||||
if contained:
|
||||
results.append((contained, offset, comm_idx))
|
||||
if len(results) == 0:
|
||||
return [(False, 0, -1)]
|
||||
return results
|
||||
|
||||
|
||||
class FP16_DeepSpeedZeroOptimizer_Stage1(object):
|
||||
"""
|
||||
FP16_DeepSpeedZeroOptimizer_Stage1 designed to reduce the memory footprint
|
||||
required for training large deep learning models.
|
||||
|
||||
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
|
||||
https://arxiv.org/abs/1910.02054
|
||||
|
||||
This version aligns with stage-1 in the paper above.
|
||||
"""
|
||||
def __init__(self,
|
||||
init_optimizer,
|
||||
static_loss_scale=1.0,
|
||||
dynamic_loss_scale=False,
|
||||
dynamic_loss_args=None,
|
||||
verbose=True,
|
||||
dp_process_group=None,
|
||||
partition_size=None,
|
||||
mpu=None,
|
||||
all_gather_partitions=True,
|
||||
allgather_size=500000000,
|
||||
clip_grad=0.0,
|
||||
max_elements_per_comm=5e8):
|
||||
|
||||
if dp_process_group is not None and partition_size is not None:
|
||||
raise ValueError("Cannot specify both dp_process_group "
|
||||
"and partition size")
|
||||
|
||||
if dp_process_group is None:
|
||||
dp_process_group = _initialize_parameter_parallel_groups(partition_size)
|
||||
|
||||
if not torch.cuda.is_available:
|
||||
raise SystemError("Cannot use fp16 without CUDA.")
|
||||
self.optimizer = init_optimizer
|
||||
|
||||
self.verbose = verbose
|
||||
self.dp_process_group = dp_process_group
|
||||
|
||||
# TODO: automatically turn off if #params > some_limit
|
||||
self.all_gather_partitions = all_gather_partitions
|
||||
self.allgather_size = allgather_size
|
||||
|
||||
self.max_elements_per_comm = max_elements_per_comm
|
||||
print("max_elements_per_comm={}".format(max_elements_per_comm))
|
||||
|
||||
# param flattened by groups
|
||||
self.fp16_groups = []
|
||||
self.fp16_groups_flat = []
|
||||
|
||||
# Setup bookkeeping data structures depending on partitioning type
|
||||
|
||||
# parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids]
|
||||
self.parallel_sub_partitioned_fp16_groups = []
|
||||
# same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
|
||||
self.parallel_comm_sub_partitioned_fp16_groups = []
|
||||
|
||||
# 32-bit sub-partitions of the parallel partitioned parameters
|
||||
# that this process will update
|
||||
self.local_sub_partitions_of_fp32_groups = []
|
||||
|
||||
# param partition info
|
||||
|
||||
# parameters in each group that will not be updated by this process directly
|
||||
self.params_not_local = []
|
||||
|
||||
# parameters that will be updated by this process directly
|
||||
self.params_in_rank_sub_partitions = []
|
||||
|
||||
# parameter offsets for parameters in sub-partitions. Parameter
|
||||
# boundaries may not align with sub-partition boundaries
|
||||
# so we need to keep track of the offsets
|
||||
self.params_in_rank_sub_partitions_offsets = []
|
||||
|
||||
# number of elements per sub-partition in each group
|
||||
self.sub_partition_sizes = []
|
||||
|
||||
# number of communication intervals for each group
|
||||
self.num_comm_intervals_per_group = []
|
||||
|
||||
local_rank = dist.get_rank(group=self.dp_process_group)
|
||||
|
||||
# loop to deal with groups
|
||||
for i, param_group in enumerate(self.optimizer.param_groups):
|
||||
# push this group to list before modify
|
||||
self.fp16_groups.append(param_group['params'])
|
||||
|
||||
# flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
|
||||
# RS: create aligned sub-partitions
|
||||
self.fp16_groups_flat.append(
|
||||
flatten_dense_tensors_sub_partition_aligned(
|
||||
tensor_list=self.fp16_groups[i],
|
||||
dp=dist.get_world_size(group=self.dp_process_group),
|
||||
max_elements_per_comm=self.max_elements_per_comm,
|
||||
pg=self.dp_process_group))
|
||||
|
||||
# TODO: I don't think this does anything?
|
||||
# set model fp16 weight to slices of flattened buffer
|
||||
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
|
||||
self.fp16_groups[i])
|
||||
for p, q in zip(self.fp16_groups[i], updated_params):
|
||||
p.data = q.data
|
||||
|
||||
# divide the flat weights into near equal partition equal to the data parallel degree
|
||||
# each process will compute on a different part of the partition
|
||||
# RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
|
||||
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
|
||||
self.get_data_parallel_sub_partitions(
|
||||
tensor=self.fp16_groups_flat[i],
|
||||
max_elements_per_comm=self.max_elements_per_comm,
|
||||
world_size=dist.get_world_size(
|
||||
group=self.dp_process_group),
|
||||
dp_process_group=self.dp_process_group
|
||||
)
|
||||
self.parallel_comm_sub_partitioned_fp16_groups.append(
|
||||
comm_partitions) # comm -> rank
|
||||
self.parallel_sub_partitioned_fp16_groups.append(
|
||||
dp_sub_partitions) # rank -> comm
|
||||
self.sub_partition_sizes.append(sub_partition_size)
|
||||
self.num_comm_intervals_per_group.append(num_comm_intervals)
|
||||
# data_parallel_partitions = self.get_data_parallel_partitions(self.fp16_groups_flat[i])
|
||||
# self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
|
||||
|
||||
# a partition of the fp32 master weights that will be updated by this process
|
||||
# RS: store/detach/cast our local sub-partitions
|
||||
local_sub_partitions = []
|
||||
for sub_partition in self.parallel_sub_partitioned_fp16_groups[i][
|
||||
local_rank]:
|
||||
fp32_sub_partition = sub_partition.clone().float().detach()
|
||||
fp32_sub_partition.requires_grad = True
|
||||
local_sub_partitions.append(fp32_sub_partition)
|
||||
self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions)
|
||||
|
||||
# modify optimizer of have flat master weight
|
||||
# self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
|
||||
param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]
|
||||
|
||||
# RS: divide up the sub-partitions and keep track of offsets for each param
|
||||
# partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
|
||||
params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
|
||||
params_not_local = self.get_all_sub_partition_info(
|
||||
tensor_list=self.fp16_groups[i],
|
||||
all_element_intervals=element_intervals,
|
||||
local_rank=local_rank,
|
||||
world_size=dist.get_world_size(group=self.dp_process_group)
|
||||
)
|
||||
|
||||
self.params_in_rank_sub_partitions.append(params_in_rank_sub_partition)
|
||||
self.params_not_local.append(params_not_local)
|
||||
self.params_in_rank_sub_partitions_offsets.append(
|
||||
params_in_rank_sub_partitions_offsets)
|
||||
|
||||
# we may have a way of fusing dynamic scale. Do not support for now
|
||||
if dynamic_loss_scale:
|
||||
if dynamic_loss_args is None:
|
||||
self.loss_scaler = DynamicLossScaler()
|
||||
else:
|
||||
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
|
||||
|
||||
self.dynamic_loss_scale = True
|
||||
|
||||
else:
|
||||
self.dynamic_loss_scale = False
|
||||
self.loss_scaler = LossScaler(scale=static_loss_scale)
|
||||
self.cur_iter = 0
|
||||
|
||||
self.mpu = mpu
|
||||
self.clip_grad = clip_grad
|
||||
|
||||
self.overflow = False
|
||||
self.overflow_checker = CheckOverflow(self.fp16_groups,
|
||||
mpu=self.mpu,
|
||||
zero_reduce_scatter=True)
|
||||
|
||||
@staticmethod
|
||||
def get_data_parallel_sub_partitions(tensor,
|
||||
max_elements_per_comm,
|
||||
world_size,
|
||||
dp_process_group=None):
|
||||
total_num_elements = tensor.numel()
|
||||
|
||||
# if total elements is less than our max, revert to splitting into dp partitions
|
||||
max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
|
||||
sub_partition_size = int(max_elements_per_comm // world_size)
|
||||
|
||||
# Ensure partition alignment was done correctly
|
||||
num_sub_partitions = int(total_num_elements // sub_partition_size)
|
||||
assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements, sub_partition_size)
|
||||
|
||||
# Ensure comm interval alignment was done correctly.
|
||||
num_comm_intervals = int(num_sub_partitions // world_size)
|
||||
assert num_sub_partitions % world_size == 0, "{} % {} != 0".format(num_sub_partitions, world_size)
|
||||
|
||||
if not dist.is_initialized() or dist.get_rank(group=dp_process_group) == 0:
|
||||
print("**** partition info:")
|
||||
print("\t total_num_elements=", total_num_elements)
|
||||
print("\t world_size=", world_size)
|
||||
print("\t max_elements_per_comm=", max_elements_per_comm)
|
||||
print("\t sub_partition_size=", sub_partition_size)
|
||||
print("\t num_sub_partitions=", num_sub_partitions)
|
||||
print("\t num_comm_intervals=", num_comm_intervals)
|
||||
print("****")
|
||||
|
||||
# [comm_id] -> [rank]
|
||||
comm_partitions = []
|
||||
for _ in range(num_comm_intervals):
|
||||
comm_partitions.append([])
|
||||
|
||||
start = 0
|
||||
comm_id = 0
|
||||
element_intervals = defaultdict(
|
||||
list) # [rank] -> [(start,end), (start,end), ...]
|
||||
for idx in range(num_sub_partitions):
|
||||
rank_id = idx % world_size
|
||||
sub_partition = tensor.narrow(0, start, sub_partition_size)
|
||||
element_intervals[rank_id].append((start, start + sub_partition_size))
|
||||
comm_partitions[comm_id].append(sub_partition)
|
||||
start = start + sub_partition_size
|
||||
if rank_id == (world_size - 1):
|
||||
comm_id += 1
|
||||
|
||||
# [rank] -> [comm_id]
|
||||
sub_partitions = []
|
||||
for _ in range(world_size):
|
||||
sub_partitions.append([])
|
||||
for comm_id, partitions in enumerate(comm_partitions):
|
||||
for rank_id, partition in enumerate(partitions):
|
||||
sub_partitions[rank_id].append(partition)
|
||||
|
||||
return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
|
||||
|
||||
@staticmethod
|
||||
def get_all_sub_partition_info(tensor_list,
|
||||
all_element_intervals,
|
||||
local_rank,
|
||||
world_size):
|
||||
params_not_local = []
|
||||
|
||||
# [rank] -> [comm-id] -> [param/offset]
|
||||
params_in_rank_sub_partition = []
|
||||
params_in_rank_sub_partitions_offsets = []
|
||||
|
||||
for rank in range(world_size):
|
||||
params_in_local_sub_partition = []
|
||||
local_sub_partition_offsets = []
|
||||
comm_tensor_list = []
|
||||
comm_offset_list = []
|
||||
current_index = 0
|
||||
prev_comm_idx = 0
|
||||
for iii, tensor in enumerate(tensor_list):
|
||||
tensor_size = tensor.numel()
|
||||
#if local_rank == 0:
|
||||
# #print("rank={}, current_index={}, tensor_size={}, tensor-idx={}".format(rank,
|
||||
# current_index, tensor_size, iii))
|
||||
results_list = _range_check(current_index,
|
||||
all_element_intervals[rank],
|
||||
tensor_size)
|
||||
for contained, offset, comm_idx in results_list:
|
||||
#if local_rank == 0:
|
||||
# print("rank={}, contained={}, offset={}, comm_idx={}".format(rank, contained,
|
||||
# offset, comm_idx))
|
||||
if contained:
|
||||
if prev_comm_idx != comm_idx:
|
||||
params_in_local_sub_partition.append(comm_tensor_list)
|
||||
comm_tensor_list = []
|
||||
local_sub_partition_offsets.append(comm_offset_list)
|
||||
comm_offset_list = []
|
||||
comm_tensor_list.append(tensor)
|
||||
comm_offset_list.append(offset)
|
||||
prev_comm_idx = comm_idx
|
||||
elif rank == local_rank:
|
||||
params_not_local.append(tensor)
|
||||
|
||||
current_index = current_index + tensor_size
|
||||
|
||||
#assert len(comm_tensor_list) > 0
|
||||
#assert len(comm_offset_list) > 0
|
||||
params_in_local_sub_partition.append(comm_tensor_list)
|
||||
local_sub_partition_offsets.append(comm_offset_list)
|
||||
|
||||
params_in_rank_sub_partition.append(params_in_local_sub_partition)
|
||||
params_in_rank_sub_partitions_offsets.append(local_sub_partition_offsets)
|
||||
|
||||
return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
|
||||
|
||||
@staticmethod
|
||||
def get_flat_sub_partitions(comm_tensor_list,
|
||||
comm_param_offsets,
|
||||
sub_partition_size,
|
||||
dtype,
|
||||
num_comm_intervals=None,
|
||||
default_device=None,
|
||||
return_partition_params=False):
|
||||
partition_params = []
|
||||
final_param_offsets = []
|
||||
flat_sub_partitions = []
|
||||
for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
|
||||
flat_tensor_list = []
|
||||
current_size = 0
|
||||
my_offsets = []
|
||||
my_params = []
|
||||
|
||||
if dtype is None:
|
||||
dtype = tensor_list[0].dtype
|
||||
|
||||
for i, tensor in enumerate(tensor_list):
|
||||
if tensor.grad is None:
|
||||
tensor.grad = torch.zeros(tensor.size(),
|
||||
dtype=tensor.dtype,
|
||||
device=tensor.device)
|
||||
param = tensor
|
||||
tensor = tensor.grad
|
||||
num_elements = tensor.numel()
|
||||
tensor_offset = 0
|
||||
|
||||
#we need to offset to get to the right element
|
||||
if i == 0 and param_offsets[i] > 0:
|
||||
tensor_offset = param_offsets[i]
|
||||
num_elements = num_elements - tensor_offset
|
||||
|
||||
# We don't need all elements of the tensor if this tensor is
|
||||
# larger than we have space for in our curr sub-partition
|
||||
if num_elements > (sub_partition_size - current_size):
|
||||
num_elements = sub_partition_size - current_size
|
||||
|
||||
#we need a narrow view of the tensor based on the tensor offset and number of elements that
|
||||
#we need from this tensor
|
||||
if tensor_offset > 0 or num_elements < tensor.numel():
|
||||
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
|
||||
0,
|
||||
int(tensor_offset),
|
||||
int(num_elements)).to(dtype))
|
||||
else:
|
||||
flat_tensor_list.append(tensor.to(dtype))
|
||||
my_params.append(param)
|
||||
|
||||
#remember offset into partition and #elems for this tensor
|
||||
my_offsets.append((current_size, num_elements))
|
||||
|
||||
current_size = current_size + num_elements
|
||||
|
||||
#this means its the last partition and does not align with the dp boundary. We need to pad before flattening
|
||||
if current_size < sub_partition_size:
|
||||
my_offsets.append((None, None))
|
||||
my_params.append(None)
|
||||
if len(tensor_list) == 0:
|
||||
assert default_device != None
|
||||
flat_tensor_list.append(
|
||||
torch.zeros(int(sub_partition_size - current_size),
|
||||
dtype=dtype,
|
||||
device=default_device))
|
||||
else:
|
||||
flat_tensor_list.append(
|
||||
torch.zeros(int(sub_partition_size - current_size),
|
||||
dtype=dtype,
|
||||
device=tensor_list[0].device))
|
||||
partition_params.append(my_params) #flat_tensor_list)
|
||||
final_param_offsets.append(my_offsets)
|
||||
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
|
||||
flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
|
||||
if num_comm_intervals is not None and len(
|
||||
flat_sub_partitions) < num_comm_intervals:
|
||||
#print("padding w. sub partitions to ensure uniform communication")
|
||||
device = flat_sub_partitions[0].device
|
||||
for _ in range(num_comm_intervals - len(flat_sub_partitions)):
|
||||
flat_sub_partitions.append(
|
||||
torch.zeros(int(sub_partition_size),
|
||||
dtype=dtype,
|
||||
device=device))
|
||||
partition_params.append([None])
|
||||
final_param_offsets.append([(None, None)])
|
||||
|
||||
if return_partition_params:
|
||||
assert len(flat_sub_partitions) == len(partition_params)
|
||||
assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params), len(final_param_offsets))
|
||||
return flat_sub_partitions, partition_params, final_param_offsets
|
||||
return flat_sub_partitions
|
||||
|
||||
def zero_grad(self, set_grads_to_None=True):
|
||||
"""
|
||||
Zero FP16 parameter grads.
|
||||
"""
|
||||
# FP32 grad should never exist.
|
||||
# For speed, set model fp16 grad to None by default
|
||||
for group in self.fp16_groups:
|
||||
for p in group:
|
||||
if set_grads_to_None:
|
||||
p.grad = None
|
||||
else:
|
||||
if p.grad is not None:
|
||||
p.grad.detach_()
|
||||
p.grad.zero_()
|
||||
|
||||
def free_grad_in_param_list(self, param_list):
|
||||
for p in param_list:
|
||||
if isinstance(p, list):
|
||||
for _p in p:
|
||||
_p.grad = None
|
||||
else:
|
||||
p.grad = None
|
||||
|
||||
def reduce_scatter_gradients(self,
|
||||
postscale_gradients,
|
||||
gradient_predivide_factor,
|
||||
gradient_average):
|
||||
world_size = dist.get_world_size(group=self.dp_process_group)
|
||||
local_rank = dist.get_rank(group=self.dp_process_group)
|
||||
|
||||
for i, group in enumerate(self.fp16_groups):
|
||||
partition_param_map = {}
|
||||
param_partition_map = {}
|
||||
my_params = set()
|
||||
|
||||
# [rank] -> [comm] -> partition
|
||||
num_comm_intervals = self.num_comm_intervals_per_group[i]
|
||||
all_sub_partitions = []
|
||||
for rank in range(world_size):
|
||||
# gsp is list of partitions indexed by comm_idx
|
||||
#FIXME: currently hardcoding fp16, should infer dtype
|
||||
grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
|
||||
comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
|
||||
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
|
||||
sub_partition_size=self.sub_partition_sizes[i],
|
||||
dtype=torch.half, #self.params_in_rank_sub_partitions[i][rank][0][0].dtype,
|
||||
num_comm_intervals=self.num_comm_intervals_per_group[i],
|
||||
default_device='cuda', #self.params_in_rank_sub_partitions[i][rank][0][0].device,
|
||||
return_partition_params=True)
|
||||
all_sub_partitions.append(grad_sub_partitions)
|
||||
|
||||
# create map from partition -> params in that partition
|
||||
for comm_idx, part in enumerate(grad_sub_partitions):
|
||||
partition_param_map[part] = (partition_params[comm_idx],
|
||||
param_offsets[comm_idx])
|
||||
|
||||
for comm_idx, params in enumerate(partition_params):
|
||||
for pidx, p in enumerate(params):
|
||||
# store the parameters we care about locally
|
||||
if rank == local_rank:
|
||||
my_params.add(p)
|
||||
# map from param -> partitions
|
||||
if p in param_partition_map:
|
||||
param_partition_map[p].append(grad_sub_partitions[comm_idx])
|
||||
else:
|
||||
param_partition_map[p] = [grad_sub_partitions[comm_idx]]
|
||||
|
||||
assert len(grad_sub_partitions) == num_comm_intervals
|
||||
|
||||
if not postscale_gradients:
|
||||
raise NotImplementedError("pre-scale_gradients is not implemented")
|
||||
|
||||
all_comm_partitions = []
|
||||
for comm_idx in range(num_comm_intervals):
|
||||
single_comm_all_partitions = []
|
||||
for rank in range(world_size):
|
||||
single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
|
||||
dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
|
||||
input_list=single_comm_all_partitions,
|
||||
group=self.dp_process_group)
|
||||
|
||||
if gradient_average:
|
||||
for partition in single_comm_all_partitions:
|
||||
partition.mul_(gradient_predivide_factor / world_size)
|
||||
|
||||
all_comm_partitions.append(single_comm_all_partitions)
|
||||
|
||||
for p in my_params:
|
||||
partitions = param_partition_map[p]
|
||||
parts = []
|
||||
for part in partitions:
|
||||
params, offsets = partition_param_map[part]
|
||||
found = False
|
||||
for p_idx, _p in enumerate(params):
|
||||
if p.__hash__() == _p.__hash__():
|
||||
found = True
|
||||
if offsets[p_idx][0] is not None:
|
||||
my_part = part.narrow(0,
|
||||
offsets[p_idx][0],
|
||||
offsets[p_idx][1])
|
||||
parts.append(my_part)
|
||||
assert found
|
||||
if p is not None:
|
||||
updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
|
||||
p.grad.copy_(updated_grad[0])
|
||||
|
||||
def step(self, closure=None):
|
||||
# First compute norm for all group so we know if there is overflow
|
||||
|
||||
self.overflow = self.overflow_checker.check()
|
||||
|
||||
prev_scale = self.loss_scale
|
||||
self._update_scale(self.overflow)
|
||||
if self.overflow:
|
||||
self.zero_grad()
|
||||
if self.verbose:
|
||||
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
|
||||
"scale: {}, reducing to {}".format(prev_scale,
|
||||
self.loss_scale))
|
||||
return self.overflow
|
||||
|
||||
norm_groups = []
|
||||
local_sub_partitions_grad_groups = []
|
||||
|
||||
partition_id = dist.get_rank(group=self.dp_process_group)
|
||||
for i, group in enumerate(self.fp16_groups):
|
||||
|
||||
#TODO RS: update get grad norm to support sub partitions
|
||||
norm_groups.append(get_grad_norm(group, mpu=self.mpu))
|
||||
|
||||
#RS: update free grads w.r.t. sub partitions
|
||||
#free gradients for all the parameters that are not updated by this process
|
||||
self.free_grad_in_param_list(self.params_not_local[i])
|
||||
|
||||
#create flat gradients for parameters updated by this process
|
||||
#tensor_list, first_offset, partition_size, dtype
|
||||
#single_grad_partition = self.get_flat_partition(
|
||||
# tensor_list=self.params_in_partition[i],
|
||||
# first_offset=self.first_offset[i],
|
||||
# partition_size=self.partition_size[i],
|
||||
# dtype=self.single_partition_of_fp32_groups[i].dtype
|
||||
#)
|
||||
|
||||
#TODO RS: can we safely use dtype of the first sub-partition? i think so
|
||||
local_grad_sub_partitions = self.get_flat_sub_partitions(
|
||||
comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
|
||||
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
|
||||
[partition_id],
|
||||
sub_partition_size=self.sub_partition_sizes[i],
|
||||
dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
|
||||
num_comm_intervals=self.num_comm_intervals_per_group[i],
|
||||
default_device=self.local_sub_partitions_of_fp32_groups[i][0].device)
|
||||
|
||||
#RS: update all our local params with sub-partition grads
|
||||
#print("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions)))
|
||||
for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
|
||||
sub_partition_param.grad = local_grad_sub_partitions[idx]
|
||||
#self.single_partition_of_fp32_groups[i].grad = single_grad_partition
|
||||
|
||||
#RS: update free grads for sub-partitions
|
||||
#release all the gradient since we have already created a necessary copy in dp_grad_partition
|
||||
self.free_grad_in_param_list(
|
||||
self.params_in_rank_sub_partitions[i][partition_id])
|
||||
|
||||
local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
|
||||
|
||||
#RS: update unscale/clip with sub partitions
|
||||
self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups)
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
#RS: clear our sub partition grads
|
||||
#get rid of the fp32 gradients. Not needed anymore
|
||||
for group in self.local_sub_partitions_of_fp32_groups:
|
||||
for idx, sub_partition_param in enumerate(group):
|
||||
sub_partition_param.grad = None
|
||||
#group.grad = None
|
||||
|
||||
#NOTE RS: removed norm_groups outer loop from original code, i don't think it's needed
|
||||
#RS: copy all sub-partition fp32 data to fp16 sub partitions
|
||||
# copy fp32 param data to fp16 partitions w.r.t. our local rank
|
||||
for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
|
||||
for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
|
||||
local_sub_partition_param_fp16.data.copy_(
|
||||
local_sub_partition_param_fp32.data)
|
||||
|
||||
#RS: all_gather/broadcast sub-partitions in separate comm calls
|
||||
#gather the updated weights from everyone
|
||||
for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
|
||||
for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
|
||||
dist.all_gather(sub_partitions,
|
||||
sub_partitions[partition_id],
|
||||
group=self.dp_process_group)
|
||||
|
||||
# TODO: we probably don't need this? just to be safe
|
||||
for i in range(len(norm_groups)):
|
||||
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
|
||||
self.fp16_groups[i])
|
||||
for p, q in zip(self.fp16_groups[i], updated_params):
|
||||
p.data = q.data
|
||||
|
||||
return self.overflow
|
||||
|
||||
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
|
||||
total_norm = 0.0
|
||||
for norm in norm_groups:
|
||||
total_norm += norm**2.0
|
||||
total_norm = math.sqrt(total_norm)
|
||||
|
||||
# compute combined scale factor for this group
|
||||
combined_scale = self.loss_scale
|
||||
if self.clip_grad > 0.:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
|
||||
if clip > 1:
|
||||
combined_scale = clip * self.loss_scale
|
||||
|
||||
for grad in grad_groups_flat:
|
||||
if isinstance(grad, list):
|
||||
sub_partitions = grad
|
||||
for g in sub_partitions:
|
||||
g.data.mul_(1. / combined_scale)
|
||||
else:
|
||||
grad.data.mul_(1. / combined_scale)
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
|
||||
|
||||
def _update_scale(self, has_overflow=False):
|
||||
self.loss_scaler.update_scale(has_overflow)
|
||||
|
||||
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
|
||||
def _get_state(self):
|
||||
return self.optimizer.state
|
||||
|
||||
def _set_state(self, value):
|
||||
self.optimizer.state = value
|
||||
|
||||
state = property(_get_state, _set_state)
|
||||
|
||||
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
|
||||
# (for example, to adjust the learning rate)
|
||||
def _get_param_groups(self):
|
||||
return self.optimizer.param_groups
|
||||
|
||||
def _set_param_groups(self, value):
|
||||
self.optimizer.param_groups = value
|
||||
|
||||
param_groups = property(_get_param_groups, _set_param_groups)
|
||||
|
||||
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
|
||||
def _get_loss_scale(self):
|
||||
return self.loss_scaler.loss_scale
|
||||
|
||||
def _set_loss_scale(self, value):
|
||||
self.loss_scaler.cur_scale = value
|
||||
|
||||
loss_scale = property(_get_loss_scale, _set_loss_scale)
|
||||
cur_scale = property(_get_loss_scale, _set_loss_scale)
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
|
||||
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
|
||||
of the contained Pytorch optimizer.
|
||||
Example::
|
||||
checkpoint = {}
|
||||
checkpoint['model'] = model.state_dict()
|
||||
checkpoint['optimizer'] = optimizer.state_dict()
|
||||
torch.save(checkpoint, "saved.pth")
|
||||
"""
|
||||
state_dict = {}
|
||||
state_dict['loss_scaler'] = self.loss_scaler
|
||||
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
|
||||
state_dict['overflow'] = self.overflow
|
||||
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
|
||||
state_dict[
|
||||
'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict, load_optimizer_states=True):
|
||||
"""
|
||||
Loads a state_dict created by an earlier call to state_dict().
|
||||
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
|
||||
whose parameters in turn came from ``model``, it is expected that the user
|
||||
will call ``model.load_state_dict()`` before
|
||||
``fp16_optimizer_instance.load_state_dict()`` is called.
|
||||
Example::
|
||||
model = torch.nn.Linear(D_in, D_out).cuda().half()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
|
||||
...
|
||||
checkpoint = torch.load("saved.pth")
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
"""
|
||||
# I think it should actually be ok to reload the optimizer before the model.
|
||||
self.loss_scaler = state_dict['loss_scaler']
|
||||
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
|
||||
self.overflow = state_dict['overflow']
|
||||
if load_optimizer_states:
|
||||
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
|
||||
|
||||
for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
|
||||
for curr_param, saved_param in zip(curr_group, saved_group):
|
||||
curr_param.data.copy_(saved_param.data)
|
|
@ -0,0 +1,24 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
|
||||
data_parallel_size = int(dist.get_world_size())
|
||||
if parameter_parallel_size is None:
|
||||
parameter_parallel_size = int(data_parallel_size)
|
||||
print(data_parallel_size, parameter_parallel_size)
|
||||
assert data_parallel_size % parameter_parallel_size == 0, \
|
||||
'world size should be divisible by parameter parallel size'
|
||||
rank = dist.get_rank()
|
||||
my_group = None
|
||||
for i in range(dist.get_world_size() // parameter_parallel_size):
|
||||
ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
my_group = group
|
||||
return my_group
|
||||
|
||||
|
||||
def pprint(msg):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(msg)
|
|
@ -29,6 +29,14 @@ collections:
|
|||
tutorials:
|
||||
output: true
|
||||
permalink: /:collection/:path/
|
||||
order:
|
||||
- getting-started.md
|
||||
- azure.md
|
||||
- cifar-10.md
|
||||
- bert-pretraining.md
|
||||
- megatron.md
|
||||
- 1Cycle.md
|
||||
- lrrt.md
|
||||
|
||||
defaults:
|
||||
- scope:
|
||||
|
|
|
@ -18,7 +18,7 @@ lnav:
|
|||
children:
|
||||
- title: "Installation"
|
||||
url: /getting-started/#installation
|
||||
- title: "Writing Models"
|
||||
- title: "Writing models"
|
||||
url: /getting-started/#writing-deepspeed-models
|
||||
- title: "Training"
|
||||
url: /getting-started/#training
|
||||
|
@ -37,19 +37,25 @@ lnav:
|
|||
url: /docs/config-json/#communication-options
|
||||
- title: "FP16"
|
||||
url: /docs/config-json/#fp16-training-options
|
||||
- title: "ZeRO optimizations"
|
||||
url: /docs/config-json/#zero-optimizations-for-fp16-training
|
||||
- title: "Logging"
|
||||
url: /docs/config-json/#logging
|
||||
- title: "Activation checkpointing"
|
||||
url: /docs/config-json/#activation-checkpointing
|
||||
- title: "Tutorials"
|
||||
url: /tutorials/
|
||||
children:
|
||||
- title: "Getting Started on Azure"
|
||||
- title: "Getting started"
|
||||
url: /getting-started/
|
||||
- title: "Getting started on Azure"
|
||||
url: /tutorials/azure/
|
||||
- title: "CIFAR-10"
|
||||
url: /tutorials/cifar-10/
|
||||
- title: "Megatron-LM GPT2"
|
||||
url: /tutorials/megatron/
|
||||
- title: "BERT Pre-training"
|
||||
url: /tutorials/bert-pretraining/
|
||||
- title: "Megatron-LM GPT2"
|
||||
url: /tutorials/megatron/
|
||||
- title: "1-Cycle Schedule"
|
||||
url: /tutorials/1Cycle/
|
||||
- title: "Learning Rate Range Test"
|
||||
|
|
|
@ -12,12 +12,6 @@ layout: archive
|
|||
{% endif %}
|
||||
|
||||
|
||||
<h2>Features Coming Soon</h2>
|
||||
{% assign soon = posts | where: "sneak_preview", "true" %}
|
||||
{% for post in soon %}
|
||||
{% include archive-single.html %}
|
||||
{% endfor %}
|
||||
|
||||
<h2>{{ site.data.ui-text[site.locale].recent_posts | default: "Recent Posts" }}</h2>
|
||||
{% assign news = posts | where: "sneak_preview", "false" %}
|
||||
{% for post in news %}
|
||||
|
|
|
@ -102,15 +102,8 @@ Example of ***scheduler***
|
|||
| ------------------------------------------------------------ | ------- |
|
||||
| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` |
|
||||
|
||||
|
||||
### FP16 training options
|
||||
|
||||
***zero\_optimization***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` |
|
||||
|
||||
***fp16***: [dictionary]
|
||||
|
||||
| Description | Default |
|
||||
|
@ -172,6 +165,66 @@ Example of ***scheduler***
|
|||
| ----------------------------------- | ------- |
|
||||
| Enable gradient clipping with value | `0` |
|
||||
|
||||
|
||||
|
||||
### ZeRO Optimizations for FP16 Training
|
||||
|
||||
Enabling and configure ZeRO memory optimizations
|
||||
```json
|
||||
"zero_optimization": {
|
||||
"stage": [0|1|2],
|
||||
"allgather_partitions": [true|false],
|
||||
"allgather_bucket_size": 500000000,
|
||||
"reduce_scatter": [true|false],
|
||||
"reduce_bucket_size": 500000000,
|
||||
"contiguous_gradients" : [true|false]
|
||||
}
|
||||
```
|
||||
|
||||
***zero\_optimization***: [dictionary]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` |
|
||||
|
||||
***stage***: [integer]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Chooses different stages of ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitiong, respectively. | `0` |
|
||||
|
||||
***allgather_partitions***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step | `true` |
|
||||
|
||||
***allgather_bucket_size***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `500000000` |
|
||||
|
||||
***reduce_scatter***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Uses reduce or reduce scatter instead of allreduce to average gradients | `true` |
|
||||
|
||||
***reduce_bucket_size***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes | `500000000` |
|
||||
|
||||
***contiguous_gradients***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. Only useful when running very large models. | `False` |
|
||||
|
||||
|
||||
|
||||
### Logging
|
||||
|
||||
***steps\_per\_print***: [integer]
|
||||
|
@ -191,3 +244,52 @@ Example of ***scheduler***
|
|||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Print out state information of DeepSpeed object after initialization | `false` |
|
||||
|
||||
### Activation Checkpointing
|
||||
```json
|
||||
"activation_checkpointing": {
|
||||
"partition_activations": false,
|
||||
"cpu_checkpointing": false,
|
||||
"contiguous_memory_optimization": false,
|
||||
"number_checkpoints": null,
|
||||
"synchronize_checkpoint_boundary": false,
|
||||
"profile": false
|
||||
}
|
||||
```
|
||||
***partition\_activations***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Enables partition activation when used with model parallelism | `false` |
|
||||
|
||||
***cpu\_checkpointing***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Offloads partitioned activations to CPU if partition_activations is enabled| `false` |
|
||||
|
||||
|
||||
***contiguous\_memory\_optimization***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Copies partitioned activations so that they are contiguous in memory | `false` |
|
||||
|
||||
***number_checkpoints***: [integer]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Total number of activation checkpoints used to allocate memory buffer for contiguous_memoty_optimization | `None` |
|
||||
|
||||
***synchronize\_checkpoint\_boundary***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` |
|
||||
|
||||
|
||||
***profile***: [boolean]
|
||||
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Logs the forward and backward time for each checkpoint function | `false` |
|
||||
|
|
|
@ -57,19 +57,33 @@ DeepSpeed is fully compatible with [Megatron](https://github.com/NVIDIA/Megatron
|
|||
Please see the [Megatron-LM tutorial](/tutorials/megatron/) for details.
|
||||
|
||||
|
||||
|
||||
## Memory and Bandwidth Optimizations
|
||||
|
||||
### The Zero Redundancy Optimizer (ZeRO)
|
||||
[ZeRO](https://arxiv.org/abs/1910.02054) is at the heart of DeepSpeed and
|
||||
enables large model training at a scale that is simply not possible with model
|
||||
parallelism alone. When enabled, ZeRO allows training models with
|
||||
over 6 billion parameters without any model parallelism, and up to 100 billion
|
||||
parameter models with model parallelism on current generation hardware.
|
||||
## The Zero Redundancy Optimizer
|
||||
The Zero Redundancy Optimizer ([ZeRO](https://arxiv.org/abs/1910.02054)) is at
|
||||
the heart of DeepSpeed and enables large model training at a scale that is
|
||||
simply not possible with model parallelism alone. When enabled, ZeRO allows
|
||||
training models with over 13 billion parameters without any model parallelism,
|
||||
and up to 200 billion parameter models with model parallelism on current
|
||||
generation hardware.
|
||||
|
||||
For more details see the [ZeRO paper](https://arxiv.org/abs/1910.02054), [GPT
|
||||
tutorial](/tutorials/megatron/) on integration with
|
||||
DeepSpeed. Additional tutorials including *BERT Tutorial*: Coming Soon.
|
||||
DeepSpeed.
|
||||
|
||||
### Optimizer State and Gradient Partitioning
|
||||
Optimizer State and Gradient Partitioning in ZeRO reduces the memory consumption of the
|
||||
model states (optimizer states, gradients and parmaeters) by 8x compared to standard
|
||||
data parallelism by partitioning these states across data parallel process instead of
|
||||
replicating them.
|
||||
|
||||
### Activation Partitioning
|
||||
Activation Partitioning is a memory optimization in ZeRO that can reduce the memory
|
||||
consumed by activations during model parallel training (MP). In MP certain
|
||||
activations maybe required by all MP processes, resulting in a replication of
|
||||
activations across MP GPUs. Activation Partitioning stores these activations in a
|
||||
partitioned state once they are used for computation in the forward propagation. These
|
||||
activations are allgathered right before they are needed again during the backward propagation.
|
||||
By storing activations in a partitioned state, ZeRO in DeepSpeed can reduce the activation
|
||||
memory footprint proportional to the MP degree.
|
||||
|
||||
### Constant Buffer Optimization (CBO)
|
||||
CBO enables high network and memory throughput while restricting memory usage to a
|
||||
|
@ -80,6 +94,17 @@ unnecessary memory overhead. CBO in DeepSpeed fuses smaller operands into approx
|
|||
pre-defined sized buffer large enough to achieve great performance without the
|
||||
unnecessary memory overhead.
|
||||
|
||||
### Contiguous Memory Optimization (CMO)
|
||||
CMO reduces reduces memory fragmentation during training, preventing out of memory errors
|
||||
due to lack of contiguous memory. Memory fragmentation is a result of interleaving between
|
||||
short lived and long lived memory objects. During the forward propagation activation
|
||||
checkpoints are long lived but the activations that recomputed are short lived. Similarly,
|
||||
during the backward computation, the activation gradients are short lived while the parameter
|
||||
gradients are long lived. CMO transfers activation checkpoints and parameter gradients
|
||||
to contiguous buffers preventing memory fragmentation.
|
||||
|
||||
## Additional Memory and Bandwidth Optimizations
|
||||
|
||||
### Smart Gradient Accumulation
|
||||
Gradient accumulation allows running larger batch size with limited memory by breaking an
|
||||
effective batch into several sequential micro-batches, and averaging the parameter
|
||||
|
@ -90,6 +115,11 @@ averaged gradients for the effective batch across all GPUs. This strategy signif
|
|||
reduces the communication involved over the approach of averaging globally for each
|
||||
micro-batch, specially when the number of micro-batches per effective batch is large.
|
||||
|
||||
### Communication Overlapping
|
||||
During back propagation, DeepSpeed can overlap the communication required for averaging
|
||||
parameter gradients that have already been computed with the ongoing gradient computation.
|
||||
This computation communication overlap, allows DeepSpeed to achieve higher throughput even
|
||||
at modest batch sizes.
|
||||
|
||||
## Training Features
|
||||
|
||||
|
@ -100,12 +130,23 @@ The DeepSpeed core API consists of just a handful of methods:
|
|||
* argument parsing: `add_config_arguments`
|
||||
* checkpointing : `load_checkpoint` and `store_checkpoint`
|
||||
|
||||
DeepSpeed supports all the features described in this document, via the use of these API,
|
||||
DeepSpeed supports most of the features described in this document, via the use of these API,
|
||||
along with a `deepspeed_config` JSON file for enabling and disabling the features.
|
||||
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
|
||||
|
||||
### Activation Checkpointing API
|
||||
|
||||
DeepSpeed's Activation Checkpoinitng API supports activation checkpoint partitioning,
|
||||
cpu checkpoiniting, and contiguous memory optimizations, while also allowing layerwise
|
||||
profiling. Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
|
||||
|
||||
|
||||
### Gradient Clipping
|
||||
```json
|
||||
{
|
||||
"gradient_clipping": 1.0
|
||||
}
|
||||
```
|
||||
DeepSpeed handles gradient clipping under the hood based on the max gradient norm
|
||||
specified by the user.
|
||||
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
|
||||
|
@ -136,8 +177,8 @@ DeepSpeed makes it easy to train with large batch sizes by enabling the LAMB Opt
|
|||
For more details on LAMB, see the [LAMB paper](https://arxiv.org/pdf/1904.00962.pdf).
|
||||
|
||||
### Memory-Efficient Training with ZeRO Optimizer
|
||||
DeepSpeed can train models up with up to 6 billion parameters without parallelism, and
|
||||
models with up to 100 billion parameters with 16-way model parallelism. This leap in
|
||||
DeepSpeed can train models up with up to 13 billion parameters without parallelism, and
|
||||
models with up to 200 billion parameters with 16-way model parallelism. This leap in
|
||||
model size is possible though the memory efficiency achieved via the ZeRO Optimizer. For
|
||||
more details see [ZeRO paper](https://arxiv.org/abs/1910.02054) .
|
||||
|
||||
|
@ -174,6 +215,10 @@ file.
|
|||
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
|
||||
```json
|
||||
{
|
||||
"wall_clock_breakdown": true
|
||||
"wall_clock_breakdown": true,
|
||||
|
||||
"activation_checkpointing": {
|
||||
"profile": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
---
|
||||
title: "ZeRO stage 2"
|
||||
sneak_preview: true
|
||||
excerpt: "Reduce memory footprint to enable training 10B models without model parallelism!"
|
||||
---
|
||||
* Reduce memory footprint of gradients
|
||||
* Train larger models: e.g., 10B parameters on 32GPUs without model parallelism
|
||||
* Train larger batch sizes
|
||||
|
||||
## Further updates coming soon!
|
|
@ -0,0 +1,21 @@
|
|||
---
|
||||
layout: single
|
||||
title: "DeepSpeed optimizes transformer kernels to achieve world's fastest BERT training record: 44 minutes on 1024 NVIDIA V100 GPUs"
|
||||
excerpt: ""
|
||||
categories: news
|
||||
new_post: true
|
||||
date: 2020-05-19 00:00:00
|
||||
---
|
||||
|
||||
|
||||
We introduce new technology to accelerate single GPU performance via
|
||||
kernel optimizations. These optimizations not only create a strong
|
||||
foundation for scaling out large models, but also improve the single GPU
|
||||
performance of highly tuned and moderately sized models like BERT by more
|
||||
than 30%, reaching a staggering performance of 66 teraflops per V100 GPU,
|
||||
which is 52% of the hardware peak. **Using these optimizations as the building
|
||||
block, DeepSpeed achieves the fastest BERT training record: 44 minutes on
|
||||
1,024 NVIDIA V100 GPUs**, compared with the best published result
|
||||
of 67 minutes on the same number and generation of GPUs.
|
||||
|
||||
**Code and tutorials are coming soon!**
|
|
@ -0,0 +1,22 @@
|
|||
---
|
||||
layout: single
|
||||
title: "ZeRO-2 empowers training models as large as 170 billion parameters up to 10x faster compared to state-of-the-art"
|
||||
excerpt: ""
|
||||
categories: news
|
||||
new_post: true
|
||||
date: 2020-05-19 01:00:00
|
||||
---
|
||||
|
||||
ZeRO-2 expands the scope of memory optimizations in the original ZeRO by
|
||||
tackling the full spectrum of memory consumption during training. More
|
||||
specifically, ZeRO-2 introduces new technology to reduce the memory footprint
|
||||
of gradients, activation memory, and fragmented memory, in addition to
|
||||
optimizer state memory optimization in the original ZeRO. Altogether, the
|
||||
memory savings empower DeepSpeed to improve the scale and speed of deep
|
||||
learning training by an order of magnitude. More concretely, ZeRO-2 allows
|
||||
training models as large as 170 billion parameters up to 10x faster compared
|
||||
to state of the art.
|
||||
|
||||
For more information on using ZeRO-2, see the [Megatron tutorial](/tutorials/megatron/).
|
||||
|
||||
For a technical deep dive, see our [technical report](https://arxiv.org/abs/1910.02054).
|
|
@ -2,6 +2,7 @@
|
|||
title: "Getting Started"
|
||||
permalink: /getting-started/
|
||||
excerpt: "First steps with DeepSpeed"
|
||||
date: 2020-05-15
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
|
|
@ -320,6 +320,43 @@ and return the states for the client model.
|
|||
|
||||
```
|
||||
|
||||
### DeepSpeed Activation Checkpoints (Optional)
|
||||
|
||||
DeepSpeed can reduce the activation memory during model parallel training by partitioning activation checkpoints across model parallel GPUs, or offloading them to CPU. These optimization are optional, and can be skipped unless activation memory becomes a memory bottlenck. To enable partition activation, we use the `deepspeed.checkpointing` API to replace Megatron's activation checkpointing and random state tracker APIs. The replacement should happen before the first invocation of these APIs.
|
||||
|
||||
a) Replace in `pretrain_gpt.py` :
|
||||
|
||||
```python
|
||||
# Optional DeepSpeed Activation Checkpointing Features
|
||||
#
|
||||
if args.deepspeed and args.deepspeed_activation_checkpointing:
|
||||
set_deepspeed_activation_checkpointing(args)
|
||||
|
||||
def set_deepspeed_activation_checkpointing(args):
|
||||
|
||||
deepspeed.checkpointing.configure(mpu,
|
||||
deepspeed_config=args.deepspeed_config,
|
||||
partition_activation=True)
|
||||
|
||||
mpu.checkpoint = deepspeed.checkpointing.checkpoint
|
||||
mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
|
||||
mpu.model_parallel_cuda_manual_seed =
|
||||
deepspeed.checkpointing.model_parallel_cuda_manual_seed
|
||||
```
|
||||
|
||||
b) Replace in `mpu/transformer.py`:
|
||||
|
||||
```python
|
||||
if deepspeed.checkpointing.is_configured():
|
||||
global get_cuda_rng_tracker, checkpoint
|
||||
get_cuda_rng_tracker = deepspeed.checkpoint.get_cuda_rng_tracker
|
||||
checkpoint = deepspeed.checkpointing.checkpoint
|
||||
|
||||
```
|
||||
|
||||
With these replacements, various DeepSpeed activation checkpointing optimizations such as activation partitioning, contiguous checkpointing, and CPU checkpointing, can be specified with either `deepspeed.checkpointing.configure` or in the `deepspeed_config` file.
|
||||
|
||||
|
||||
### Train scripts
|
||||
Assume webtext data was prepared in previous step, to start training
|
||||
Megatron-LM GPT2 model with DeepSpeed applied, execute the following command to
|
||||
|
@ -328,13 +365,18 @@ start training.
|
|||
- Single GPU run
|
||||
- run `bash scripts/ds_pretrain_gpt2.sh`
|
||||
- Multiple GPUs/Nodes run
|
||||
- run `bash scripts/ds_pretrain_gpt2_model_parallel.sh`
|
||||
- run `bash scripts/ds_zero2_pretrain_gpt2_model_parallel.sh`
|
||||
|
||||
## DeepSpeed Evaluation using GPT-2
|
||||
|
||||
|
||||
## Performance Improvements
|
||||
DeepSpeed enables training very large models effectively via the advanced [ZeRO
|
||||
optimizer](https://arxiv.org/abs/1910.02054v2). ZeRO significantly reduces the memory
|
||||
optimizer](https://arxiv.org/abs/1910.02054v2). In February, we released a sub-set
|
||||
of optimizations from ZeRO in DeepSpeed that performs optimizer state partitioning.
|
||||
We refer to them as ZeRO-1. In May, 2020 we extended ZeRO-1 in DeepSpeed to include
|
||||
additional optimizations from ZeRO including gradient and activation partitioning,
|
||||
as well as contiguous memory optimizations. We refer to this release as ZeRO-2.
|
||||
|
||||
ZeRO-2 significantly reduces the memory
|
||||
footprint for training large models which means large models can be trained with i) less
|
||||
model parallelism and ii) larger batch sizes. A lower model parallelism degree improves
|
||||
training efficiency by increasing the granularity of the computation such as the matrix
|
||||
|
@ -342,80 +384,25 @@ multiplication where performance is directly related to the size of the matrices
|
|||
Furthermore, less model parallelism also results in less communication between model
|
||||
parallel GPUs, which further boosts performance. Larger batch size has a similar effect
|
||||
of increasing the computational granularity as well as reducing communication, also
|
||||
resulting in better performance. Therefore, DeepSpeed combines ZeRO-powered data parallelism with
|
||||
Megatron-LM tensor-slicing model parallelism, which is
|
||||
significantly faster than using Megatron-LM alone.
|
||||
resulting in better performance. Therefore, with DeepSpeed and ZeRO-2 integration into Megatron,
|
||||
we elevate the model scale and speed to an entirely new level compared to Megatron alone.
|
||||
|
||||
The observed performance improvements depend on several factors such as the memory per
|
||||
GPU, the local GPU interconnect (i.e., PCI-E vs NVLINK vs NVSwitch), the model size,
|
||||
inter node network interconnect, etc. Below, we show some of the performance improvements
|
||||
from using DeepSpeed over Megatron on a 16 GPU Low Bandwidth (40 Gbps) cluster and a 400 GPU DGX-2 High Bandwidth (800 Gbps) cluster.
|
||||
For details please see the [ZeRO Paper](https://arxiv.org/abs/1910.02054v2). We also
|
||||
present performance improvement on a 64 GPU cluster along with detailed configuration
|
||||
analysis to show where the improvements come from.
|
||||
|
||||
![DeepSpeed-vs-Megatron](/assets/images/DeepSpeed-vs-Megatron.png)
|
||||
![DeepSpeed-vs-Megatron](../assets/images/zero-full.png)
|
||||
<p align="center">
|
||||
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of Nvidia Megatron-LM) over using Megatron-LM alone.</em>
|
||||
<em>Figure 2: ZeRO-2 scales to 170 billion parameters, has up to 10x higher throughput, obtains super linear speedup, and improves usability by avoiding the need for code refactoring for models up to 13 billion parameters.</em>
|
||||
</p>
|
||||
|
||||
|
||||
### On Low Bandwidth GPU Cluster
|
||||
The figure above shows that training 1.5B parameter model with DeepSpeed is
|
||||
nearly 4x faster than without DeepSpeed on a cluster with 4 nodes, 4 GPU per
|
||||
node, and 16 GPUs total. These GPUs have 16GB of memory each, and PCI-E
|
||||
interconnects GPUs within a node, and 40 Gbps infiniband across nodes.
|
||||
|
||||
The performance improvement comes from lower model parallelism degree and
|
||||
larger batch size as discussed earlier. Training 1.5B parameter model with
|
||||
Megatron-LM alone requires 4-way model parallelism, and can only fit an effective
|
||||
batch size of 32 using all 16 GPUs. On the other hand, DeepSpeed does not
|
||||
require any model-parallelism to train this model, and can support an
|
||||
effective batch size of 128 without running out of memory, resulting in
|
||||
significantly higher performance.
|
||||
More concretely, DeepSpeed and ZeRO-2 excel in four aspects (as visualized in Figure 2), supporting an order-of-magnitude bigger models, up to 10x faster, with superlinear scalability, and improved usability to democratize large model training. These four aspects are detailed below.
|
||||
|
||||
|
||||
### On High bandwidth DGX-2 GPU Cluster
|
||||
Each GPU on the DGX-2 cluster has 32 GB of memory, and GPUs inside a box is connected via
|
||||
the high-bandwidth NVSwitch. DGX-2 nodes are connected to each other via 800 Gbps (8 x 100Gbps) infiniband interconnect. As such, running a 1.5B model on DGX-2 requires less model
|
||||
parallelism, and the performance improvement from DeepSpeed for this model size is less
|
||||
significant. However, at larger model sizes, Megatron still requires significantly larger
|
||||
model parallelism degree, and can only run much smaller batch sizes than DeepSpeed.
|
||||
Therefore, as the model sizes get larger, DeepSpeed, by coming ZeRO with Megatron model parallelism, starts to significantly outperform
|
||||
using Megatron-LM alone.
|
||||
Figure 2: ZeRO-2 scales to 170 billion parameters, has up to 10x higher throughput, obtains super linear speedup, and improves usability by avoiding the need for code refactoring for models up to 13 billion parameters.
|
||||
|
||||
Model size: State-of-the-art large models such as OpenAI GPT-2, NVIDIA Megatron-LM, Google T5, and Microsoft Turing-NLG have sizes of 1.5B, 8.3B, 11B, and 17B parameters respectively. ZeRO-2 provides system support to efficiently run models of 170 billion parameters, an order-of-magnitude bigger than these largest models (Figure 2, top left).
|
||||
|
||||
### Performance Improvements with Configuration Details
|
||||
The figure below compares DeepSpeed with Megatron on a 64 GPU cluster with 4
|
||||
DGX-2 nodes. To give the readers a clear idea of source of the performance
|
||||
improvements, we also present the configuration table for both Megatron and
|
||||
DeepSpeed. It shows the smallest model parallelism degree and the largest batch
|
||||
size that can be used to train these models without running out of memory. As
|
||||
discussed above, the tables demonstrate that DeepSpeed runs with smaller model parallelism degree
|
||||
and achieves better performance.
|
||||
Speed: Improved memory efficiency powers higher throughput and faster training. Figure 2 (bottom left) shows system throughput of ZeRO-2 and ZeRO-1 (both combining ZeRO-powered data parallelism with NVIDIA Megatron-LM model parallelism) as well as using the state-of-the-art model parallelism approach Megatron-LM alone (baseline in Figure 2, bottom left). ZeRO-2 runs 100-billion-parameter models on a 400 NVIDIA V100 GPU cluster with over 38 teraflops per GPU and aggregated performance over 15 petaflops. For models of the same size, ZeRO-2 is 10x faster in training speed when compared with using Megatron-LM alone and 5x faster when compared with ZeRO-1.
|
||||
|
||||
![DeepSpeed Performance SpeedUp](/assets/images/megatron-gpt2-perf-test.png)
|
||||
<p align="center">
|
||||
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of Nvidia Megatron-LM) over using Megatron-LM alone.</em>
|
||||
</p>
|
||||
Scalability: We observe superlinear speedup (Figure 2, top right), where the performance more than doubles when the number of GPUs are doubled. ZeRO-2 reduces the memory footprint of the model states as we increase the data parallelism degree, allowing us to fit larger batch sizes per GPU and resulting in better performance.
|
||||
|
||||
Democratizing large model training: ZeRO-2 empowers model scientists to train models up to 13 billion parameters efficiently without any model parallelism that typically requires model refactoring (Figure 2, bottom right). 13 billion parameters is larger than most of the largest state-of-the-art models (such as Google T5, with 11 billion parameters). Model scientists can therefore experiment freely with large models without worrying about model parallelism. In comparison, the implementations of classic data-parallelism approaches (such as PyTorch Distributed Data Parallel) run out of memory with 1.4-billion-parameter models, while ZeRO-1 supports up to 6 billion parameters for comparison.
|
||||
|
||||
**a ) Megatron-LM GPT2 Baseline**
|
||||
|
||||
| | Model Parallelism | Data Parallelism | #gpus | batch size | layers | hidden size | attention heads | samples / sec |
|
||||
| ---- | ----------------: | ---------------: | ----: | ---------: | -----: | -----------:| --------------: | ------------: |
|
||||
| 1.5B | 2 | 32 | 64 | 512 | 48 | 1600 | 16 | 128.56 |
|
||||
| 4B | 4 | 16 | 64 | 128 | 64 | 2304 | 16 | 49.36 |
|
||||
| 8B | 4 | 16 | 64 | 128 | 72 | 3072 | 24 | 24.57 |
|
||||
| 20B | 16 | 4 | 64 | 16 | 111 | 3808 | 32 | 3.42 |
|
||||
|
||||
|
||||
|
||||
**b ) Megatron-LM GPT2 with DeepSpeed**
|
||||
|
||||
| | Model Parallelism | Data Parallelism | #gpus | batch size | layers | hidden size | attention heads | samples / sec |
|
||||
| ---- | ----------------: | ---------------: | ----: | ---------: | -----: | -----------:| --------------: | ------------: |
|
||||
| 1.5B | 1 | 64 | 64 | 2048 | 48 | 1600 | 16 | 151.35 |
|
||||
| 4B | 1 | 64 | 64 | 512 | 64 | 2304 | 16 | 75.13 |
|
||||
| 8B | 2 | 32 | 64 | 512 | 72 | 3072 | 24 | 43.52 |
|
||||
| 20B | 4 | 16 | 64 | 128 | 111 | 3808 | 32 | 12.65 |
|
||||
Furthermore, in the absence of model parallelism, these models can be trained on low bandwidth clusters while still achieving significantly better throughput compared to using model parallelism. For example, the GPT-2 model can be trained nearly 4x faster with ZeRO powered data parallelism compared to using model parallelism on a four node cluster connected with 40 Gbps Infiniband interconnect, where each node have four NVIDIA 16GB V100 GPUs connected with PCI-E. Therefore, with this performance improvement, large model training is no longer limited to GPU clusters with ultra fast interconnect but also accesible on modest clusters with limited bandwidth.
|
||||
|
|
До Ширина: | Высота: | Размер: 96 KiB После Ширина: | Высота: | Размер: 96 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 78 KiB |
До Ширина: | Высота: | Размер: 13 KiB После Ширина: | Высота: | Размер: 13 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 121 KiB |
|
@ -1,3 +1,2 @@
|
|||
tqdm
|
||||
psutil
|
||||
tensorboardX==1.8
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
Activation Checkpointing
|
||||
========================
|
||||
|
||||
The activation checkpointing API's in DeepSpeed can be used to enable a range
|
||||
of memory optimizations relating to activation checkpointing. These include
|
||||
activation partitioning across GPUs when using model parallelism, CPU
|
||||
checkpointing, contiguous memory optimizations, etc.
|
||||
|
||||
Please see the `DeepSpeed JSON config <https://www.deepspeed.ai/docs/config-json/>`_
|
||||
for the full set.
|
||||
|
||||
Here we present the activation checkpointing API. Please see the enabling
|
||||
DeepSpeed for `Megatron-LM tutorial <https://www.deepspeed.ai/tutorials/megatron/>`_
|
||||
for example usage.
|
||||
|
||||
Configuring Activation Checkpointing
|
||||
------------------------------------
|
||||
.. autofunction:: deepspeed.checkpointing.configure
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.is_configured
|
||||
|
||||
|
||||
Using Activation Checkpointing
|
||||
------------------------------
|
||||
.. autofunction:: deepspeed.checkpointing.checkpoint
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.reset
|
||||
|
||||
|
||||
Configuring and Checkpointing Random Seeds
|
||||
------------------------------------------
|
||||
.. autofunction:: deepspeed.checkpointing.get_cuda_rng_tracker
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.model_parallel_cuda_manual_seed
|
||||
|
||||
.. autoclass:: deepspeed.checkpointing.CudaRNGStatesTracker
|
||||
|
||||
.. autoclass:: deepspeed.checkpointing.CheckpointFunction
|
|
@ -0,0 +1,26 @@
|
|||
DeepSpeed Activation Checkpointing
|
||||
======================
|
||||
|
||||
The activation checkpointing API's in DeepSpeed can be used to enable a range of memory optimizations relating
|
||||
to activation checkpointing. These include activation partitioning across
|
||||
GPUs when using model parallelism, CPU Checkpointing, contiguous memory optimizations, etc.
|
||||
Please see the `DeepSpeed JSON config <https://www.deepspeed.ai/docs/config-json/>`_ for the full set.
|
||||
|
||||
Here we present the activation checkpointing API's.
|
||||
Please see the enabling DeepSpeed for Megatron-LM tutorial for usage details.
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.configure
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.is_configured
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.checkpoint
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.reset
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.get_cuda_rng_tracker
|
||||
|
||||
.. autofunction:: deepspeed.checkpointing.model_parallel_cuda_manual_seed
|
||||
|
||||
.. autoclass:: deepspeed.checkpointing.CudaRNGStatesTracker
|
||||
|
||||
.. autoclass:: deepspeed.checkpointing.CheckpointFunction
|
|
@ -71,25 +71,9 @@ html_context = {
|
|||
from unittest.mock import MagicMock
|
||||
sys.path.insert(0, os.path.abspath('../../../'))
|
||||
|
||||
# Prepend module names to class descriptions?
|
||||
add_module_names = True
|
||||
|
||||
class Mock(MagicMock):
|
||||
@classmethod
|
||||
def __getattr__(cls, name):
|
||||
return MagicMock()
|
||||
autoclass_content = 'both'
|
||||
|
||||
|
||||
MOCK_MODULES = [
|
||||
'torch',
|
||||
'torch.utils',
|
||||
'torch.utils.data',
|
||||
'torch.utils.data.distributed',
|
||||
'torch._utils',
|
||||
'torch.cuda',
|
||||
'torch.nn.modules',
|
||||
'torch.nn',
|
||||
'torch.distributed',
|
||||
'torch.distributed.distributed_c10d',
|
||||
'torch.optim',
|
||||
'torch._six'
|
||||
]
|
||||
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
||||
autodoc_mock_imports = ["torch", "apex", "mpi4py", "tensorboardX"]
|
||||
|
|
|
@ -1,15 +1,35 @@
|
|||
DeepSpeed
|
||||
=========
|
||||
|
||||
Model Setup
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Contents:
|
||||
|
||||
initialize
|
||||
checkpointing
|
||||
|
||||
Training API
|
||||
------------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
training
|
||||
|
||||
|
||||
Checkpointing API
|
||||
-----------------
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
model-checkpointing
|
||||
activation-checkpointing
|
||||
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
------------------
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
|
|
|
@ -1,6 +1,30 @@
|
|||
Initializing DeepSpeed
|
||||
======================
|
||||
Training Setup
|
||||
==============
|
||||
|
||||
.. _deepspeed-args:
|
||||
|
||||
Argument Parsing
|
||||
----------------
|
||||
DeepSpeed uses the `argparse <https://docs.python.org/3/library/argparse.html>`_ library to
|
||||
supply commandline configuration to the DeepSpeed runtime. Use ``deepspeed.add_config_arguments()``
|
||||
to add DeepSpeed's builtin arguments to your application's parser.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
parser = argparse.ArgumentParser(description='My training script.')
|
||||
parser.add_argument('--local_rank', type=int, default=-1,
|
||||
help='local rank passed from distributed launcher')
|
||||
# Include DeepSpeed configuration arguments
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
cmd_args = parser.parse_args()
|
||||
|
||||
.. autofunction:: deepspeed.add_config_arguments
|
||||
|
||||
|
||||
.. _deepspeed-init:
|
||||
|
||||
Training Initialization
|
||||
-----------------------
|
||||
The entrypoint for all training with DeepSpeed is ``deepspeed.initialize()``.
|
||||
|
||||
Example usage:
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
Model Checkpointing
|
||||
===================
|
||||
|
||||
DeepSpeed provides routines for checkpointing model state during training.
|
||||
|
||||
Loading Training Checkpoints
|
||||
----------------------------
|
||||
.. autofunction:: deepspeed.DeepSpeedLight.load_checkpoint
|
||||
|
||||
Saving Training Checkpoints
|
||||
---------------------------
|
||||
.. autofunction:: deepspeed.DeepSpeedLight.save_checkpoint
|
|
@ -0,0 +1,29 @@
|
|||
Training API
|
||||
============
|
||||
|
||||
:func:`deepspeed.initialize` returns a *model engine* in its first argument
|
||||
of type ``DeepSpeedLight``. This engine is used to progress training:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
for step, batch in enumerate(data_loader):
|
||||
#forward() method
|
||||
loss = model_engine(batch)
|
||||
|
||||
#runs backpropagation
|
||||
model_engine.backward(loss)
|
||||
|
||||
#weight update
|
||||
model_engine.step()
|
||||
|
||||
Forward Propagation
|
||||
-------------------
|
||||
.. autofunction:: deepspeed.DeepSpeedLight.forward
|
||||
|
||||
Backward Propagation
|
||||
--------------------
|
||||
.. autofunction:: deepspeed.DeepSpeedLight.backward
|
||||
|
||||
Optimizer Step
|
||||
--------------
|
||||
.. autofunction:: deepspeed.DeepSpeedLight.step
|
|
@ -8,11 +8,11 @@ DeepSpeed is a deep learning optimization library that makes distributed trainin
|
|||
efficient, and effective.
|
||||
|
||||
<p align="center"><i><b>10x Larger Models</b></i></p>
|
||||
<p align="center"><i><b>5x Faster Training</b></i></p>
|
||||
<p align="center"><i><b>10x Faster Training</b></i></p>
|
||||
<p align="center"><i><b>Minimal Code Change</b></i></p>
|
||||
|
||||
DeepSpeed can train DL models with over a hundred billion parameters on current
|
||||
generation of GPU clusters, while achieving over 5x in system performance
|
||||
generation of GPU clusters, while achieving over 10x in system performance
|
||||
compared to the state-of-art. Early adopters of DeepSpeed have already produced
|
||||
a language model (LM) with over 17B parameters called
|
||||
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
|
||||
|
@ -22,9 +22,9 @@ establishing a new SOTA in the LM category.
|
|||
{% assign news = site.posts | where: "sneak_preview", "false" %}
|
||||
{% for post in news limit:5 %}
|
||||
{% if post.link %}
|
||||
* [{{ post.title }}]({{ post.link }})
|
||||
* [{{ post.date | date: "%Y/%m/%d" }}] [{{ post.title }}]({{ post.link }}) {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %}
|
||||
{% else %}
|
||||
* [{{ post.title }}]({{ post.url }})
|
||||
* [{{ post.date | date: "%Y/%m/%d"}}] [{{ post.title }}]({{ post.url }}) {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
|
@ -54,19 +54,20 @@ DeepSpeed achieves high performance and fast convergence through a combination o
|
|||
efficiency optimizations on compute/communication/memory/IO and effectiveness
|
||||
optimizations on advanced hyperparameter tuning and optimizers. For example:
|
||||
|
||||
* DeepSpeed trains BERT-large to parity in 14 hours using 64 GPUs (4 DGX-2 boxes) and in
|
||||
3.7 hours using 256 GPUs (16 DGX-2 boxes).
|
||||
* <span style="color:dodgerblue">DeepSpeed trains BERT-large to parity in 44
|
||||
mins using 1024 V100 GPUs (64 DGX-2 boxes) and in 2.4 hours using 256 GPUs
|
||||
(16 DGX-2 boxes).</span>
|
||||
|
||||
**BERT-large Training Times**
|
||||
|
||||
| Devices | Source | Training Time (hours) |
|
||||
| ------------- | --------- | ---------------------:|
|
||||
| 64 TPUs | Google | 96 |
|
||||
| 64 V100 GPUs | DeepSpeed | **14** |
|
||||
| 256 V100 GPUs | NVIDIA | 3.9 |
|
||||
| 256 V100 GPUs | DeepSpeed | **3.7** |
|
||||
| Devices | Source | Training Time |
|
||||
| -------------- | --------- | ---------------------:|
|
||||
| 1024 V100 GPUs | DeepSpeed | **44** min|
|
||||
| 256 V100 GPUs | DeepSpeed | **2.4** hr|
|
||||
| 64 V100 GPUs | DeepSpeed | **8.68** hr|
|
||||
| 16 V100 GPUs | DeepSpeed | **33.22** hr|
|
||||
|
||||
*BERT Tutorial*: Coming Soon
|
||||
*BERT codes and tutorials will be available soon.*
|
||||
|
||||
* DeepSpeed trains GPT2 (1.5 billion parameters) 3.75x faster than state-of-art, NVIDIA
|
||||
Megatron on Azure GPUs.
|
||||
|
@ -77,37 +78,42 @@ optimizations on advanced hyperparameter tuning and optimizers. For example:
|
|||
|
||||
## Memory efficiency
|
||||
DeepSpeed provides memory-efficient data parallelism and enables training models without
|
||||
model parallelism. For example, DeepSpeed can train models with up to 6 billion parameters on
|
||||
model parallelism. For example, DeepSpeed can train models with up to 13 billion parameters on
|
||||
NVIDIA V100 GPUs with 32GB of device memory. In comparison, existing frameworks (e.g.,
|
||||
PyTorch's Distributed Data Parallel) run out of memory with 1.5 billion parameter models.
|
||||
PyTorch's Distributed Data Parallel) run out of memory with 1.4 billion parameter models.
|
||||
|
||||
DeepSpeed reduces the training memory footprint through a novel solution called Zero
|
||||
Redundancy Optimizer (ZeRO). Unlike basic data parallelism where memory states are
|
||||
replicated across data-parallel processes, ZeRO partitions model states to save
|
||||
significant memory. The current implementation (stage 1 of ZeRO) reduces memory by up to
|
||||
4x relative to the state-of-art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054).
|
||||
replicated across data-parallel processes, ZeRO partitions model states and gradients to save
|
||||
significant memory. Furthermore, it also reduces activation memory and fragmented memory.
|
||||
The current implementation (ZeRO-2) reduces memory by up to
|
||||
8x relative to the state-of-art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054), and
|
||||
in our blog posts related to
|
||||
[ZeRO-1](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/). <!-- and [ZeRO-2](linklink). -->
|
||||
|
||||
With this impressive memory reduction, early adopters of DeepSpeed have already
|
||||
produced a language model (LM) with over 17B parameters called
|
||||
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
|
||||
<a href="https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft">
|
||||
<span style="color:dodgerblue">Turing-NLG</span></a>,
|
||||
establishing a new SOTA in the LM category.
|
||||
|
||||
|
||||
## Scalability
|
||||
DeepSpeed supports efficient data parallelism, model parallelism, and their
|
||||
combination. ZeRO boosts the scaling capability and efficiency further.
|
||||
* DeepSpeed provides system support to run models up to 100 billion parameters,
|
||||
10x larger than the state-of-art (8 billion NVIDIA GPT, 11 billion Google T5).
|
||||
* DeepSpeed can run large models more efficiently, up to 6x faster for models with
|
||||
various sizes spanning 1.5B to 100B. More specifically, the data parallelism powered by ZeRO
|
||||
* <span style="color:dodgerblue">DeepSpeed provides system support to run models up to 170 billion parameters,
|
||||
10x larger than the state-of-art (8 billion NVIDIA GPT, 11 billion Google T5).</span>
|
||||
* <span style="color:dodgerblue">DeepSpeed can run large models more efficiently, up to 10x
|
||||
faster for models with
|
||||
various sizes spanning 1.5B to 170B.</span> More specifically, the data parallelism powered by ZeRO
|
||||
is complementary and can be combined with different types of model parallelism. It allows
|
||||
DeepSpeed to fit models using lower degree of model parallelism and higher batch size, offering
|
||||
significant performance gains compared to using model parallelism alone.
|
||||
|
||||
*Read more*: [technical report](https://arxiv.org/abs/1910.02054),
|
||||
*Read more*: [ZeRO paper](https://arxiv.org/abs/1910.02054),
|
||||
and [GPT tutorial](/tutorials/megatron).
|
||||
|
||||
![DeepSpeed-vs-Megatron](/assets/images/DeepSpeed-vs-Megatron.png)
|
||||
![DeepSpeed Speedup](/assets/images/deepspeed-speedup.png)
|
||||
<p align="center">
|
||||
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of NVIDIA Megatron-LM) over using Megatron-LM alone.</em>
|
||||
</p>
|
||||
|
@ -123,7 +129,7 @@ convergence to desired accuracy.
|
|||
|
||||
|
||||
## Good Usability
|
||||
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to six billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.3 billion parameters. In addition, DeepSpeed conveniently supports flexible combination of ZeRO-powered data parallelism with custom model parallelisms, such as tensor slicing of NVIDIA's Megatron-LM.
|
||||
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to 13 billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.4 billion parameters. In addition, DeepSpeed conveniently supports flexible combination of ZeRO-powered data parallelism with custom model parallelisms, such as tensor slicing of NVIDIA's Megatron-LM.
|
||||
|
||||
|
||||
## Features
|
||||
|
@ -137,12 +143,17 @@ overview](features) for descriptions and usage.
|
|||
* [Model Parallelism](features.md#model-parallelism)
|
||||
* Support for Custom Model Parallelism
|
||||
* Integration with Megatron-LM
|
||||
* [Memory and Bandwidth Optimizations](features.md#memory-and-bandwidth-optimizations)
|
||||
* The Zero Redundancy Optimizer (ZeRO)
|
||||
* Constant Buffer Optimization (CBO)
|
||||
* [The Zero Redundancy Optimizer (ZeRO)](features.md#the-zero-redundancy-optimizer)
|
||||
* Optimizer State and Gradient Partitioning
|
||||
* Activation Partitioning
|
||||
* Constant Buffer Optimization
|
||||
* Contiguous Memory Optimization
|
||||
* [Additional Memory and Bandwidth Optimizations](features.md#additional-memory-and-bandwidth-optimizations)
|
||||
* Smart Gradient Accumulation
|
||||
* Communication/Computation Overlap
|
||||
* [Training Features](features.md#training-features)
|
||||
* Simplified training API
|
||||
* Activation Checkpointing API
|
||||
* Gradient Clipping
|
||||
* Automatic loss scaling with mixed precision
|
||||
* [Training Optimizers](features.md#training-optimizers)
|
||||
|
|
|
@ -56,6 +56,19 @@ class BingBertSquadFuncTestCase(BaseTestCase):
|
|||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_gpu4_fp16_zero2(self):
|
||||
test_config = {
|
||||
"gpus": 4,
|
||||
"deepspeed": False,
|
||||
"json": "deepspeed_bsz24_fp16_zero2_config.json",
|
||||
"max_steps": 8,
|
||||
"max_epoch_steps": 4,
|
||||
"other_args": "--fp16 --print_steps 1"
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_gpu1_fp16(self):
|
||||
test_config = {
|
||||
"gpus": 1,
|
||||
|
@ -151,6 +164,7 @@ class BingBertSquadFuncTestCase(BaseTestCase):
|
|||
def suite():
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(BingBertSquadFuncTestCase('test_gpu4_fp16'))
|
||||
suite.addTest(BingBertSquadFuncTestCase('test_gpu4_fp16_zero2'))
|
||||
suite.addTest(BingBertSquadFuncTestCase('test_gpu1_fp16'))
|
||||
suite.addTest(BingBertSquadFuncTestCase('test_gpu4_fp32'))
|
||||
suite.addTest(BingBertSquadFuncTestCase('test_gpu1_fp32'))
|
||||
|
|
|
@ -1,14 +1,6 @@
|
|||
{
|
||||
"tensorboard": {
|
||||
"enabled": false,
|
||||
"job_name": "MyJob"
|
||||
},
|
||||
"zero_optimization": true,
|
||||
"disable_allgather": false,
|
||||
"allgather_size": 200000,
|
||||
"wall_clock_breakdown": false,
|
||||
"train_batch_size": 24,
|
||||
"train_micro_batch_size_per_gpu": 3,
|
||||
"train_micro_batch_size_per_gpu": 6,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
|
@ -21,5 +13,8 @@
|
|||
"gradient_clipping": 1.0,
|
||||
"fp16": {
|
||||
"enabled": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"train_batch_size": 24,
|
||||
"train_micro_batch_size_per_gpu": 6,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"weight_decay": 0.0,
|
||||
"bias_correction": false
|
||||
}
|
||||
},
|
||||
"gradient_clipping": 1.0,
|
||||
"fp16": {
|
||||
"enabled": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"train_batch_size": 24,
|
||||
"train_micro_batch_size_per_gpu": 3,
|
||||
"train_micro_batch_size_per_gpu": 6,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
|
|
|
@ -122,7 +122,7 @@ echo "deepspeed: ${enable_deepspeed}"
|
|||
echo "other_args: ${other_args}"
|
||||
|
||||
EFFECTIVE_BATCH_SIZE=${batch_size}
|
||||
MAX_GPU_BATCH_SIZE=3
|
||||
MAX_GPU_BATCH_SIZE=6
|
||||
PER_GPU_BATCH_SIZE=$((EFFECTIVE_BATCH_SIZE/num_gpus))
|
||||
if [[ $PER_GPU_BATCH_SIZE -lt $MAX_GPU_BATCH_SIZE ]]; then
|
||||
GRAD_ACCUM_STEPS=1
|
||||
|
|
5
tests/model/Megatron_GPT2/ds_config_func_bs4.json → tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json
Normal file → Executable file
5
tests/model/Megatron_GPT2/ds_config_func_bs4.json → tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json
Normal file → Executable file
|
@ -2,10 +2,11 @@
|
|||
"train_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": true,
|
||||
"zero_optimization": {
|
||||
"stage":1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"train_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": {
|
||||
"stage":2
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
}
|
||||
},
|
||||
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
}
|
||||
}
|
|
@ -2,12 +2,14 @@
|
|||
"train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": false,
|
||||
"zero_optimization": {
|
||||
"stage":0
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
}
|
||||
},
|
||||
|
||||
|
|
5
tests/model/Megatron_GPT2/ds_config_func_bs8.json → tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json
Normal file → Executable file
5
tests/model/Megatron_GPT2/ds_config_func_bs8.json → tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json
Normal file → Executable file
|
@ -2,10 +2,11 @@
|
|||
"train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": true,
|
||||
"zero_optimization":{
|
||||
"stage":1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": {
|
||||
"stage":2
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
}
|
||||
},
|
||||
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"activation_checkpointing": {
|
||||
"partition_activations": true,
|
||||
"contiguous_memory_optimization": true
|
||||
}
|
||||
|
||||
}
|
|
@ -2,10 +2,11 @@
|
|||
"train_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": true,
|
||||
"zero_optimization": {
|
||||
"stage":2
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
|
|
|
@ -2,11 +2,10 @@
|
|||
"train_batch_size": 16,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": true,
|
||||
"zero_optimization": 1,
|
||||
"disable_allgather": true,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
"train_batch_size": 32,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": true,
|
||||
"zero_optimization": {
|
||||
"stage":1
|
||||
},
|
||||
"disable_allgather": true,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
|
|
|
@ -2,11 +2,10 @@
|
|||
"train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 1,
|
||||
"zero_optimization": true,
|
||||
"zero_optimization": 1,
|
||||
"disable_allgather": true,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"legacy_fusion": false,
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"max_grad_norm": 1.0
|
||||
|
|
|
@ -85,6 +85,7 @@ gpt_options=" \
|
|||
--checkpoint-activations \
|
||||
--checkpoint-num-layers ${ckpt_num_layers} \
|
||||
--fp16 \
|
||||
--cache-dir /tmp/cache_dir \
|
||||
--log-interval 1 \
|
||||
${other_args} \
|
||||
${ds_opt} \
|
||||
|
@ -92,7 +93,7 @@ gpt_options=" \
|
|||
"
|
||||
|
||||
work_dir="../../../DeepSpeedExamples/Megatron-LM/"
|
||||
run_cmd="(cd ${work_dir} && deepspeed --num_gpus $gpus pretrain_gpt2.py ${gpt_options})"
|
||||
run_cmd="(cd ${work_dir} && deepspeed --num_nodes $nodes --num_gpus $gpus pretrain_gpt2.py ${gpt_options})"
|
||||
echo ${run_cmd}
|
||||
eval ${run_cmd}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||
def tearDown(self):
|
||||
os.chdir(self.save_dir)
|
||||
|
||||
def test_mp4_gpu16_node1_with_zero(self):
|
||||
def test_mp4_gpu16_node1_with_zero1(self):
|
||||
test_config = {
|
||||
"mp": 2,
|
||||
"gpus": 4,
|
||||
|
@ -55,12 +55,34 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||
"seq_length": 256,
|
||||
"heads": 16,
|
||||
"deepspeed": True,
|
||||
"tag": "ds_zero",
|
||||
"tag": "ds_zero1",
|
||||
"zero": True,
|
||||
"other_args": "",
|
||||
"checkpoint_name": "ckpt_mp4_gpu16_w_zero",
|
||||
"checkpoint_name": "ckpt_mp4_gpu16_w_zero1",
|
||||
"checkpoint_interval": 1000,
|
||||
"json": "ds_config_func_bs8.json",
|
||||
"json": "ds_config_func_bs8_zero1.json",
|
||||
}
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp4_gpu16_node1_with_zero2(self):
|
||||
test_config = {
|
||||
"mp": 2,
|
||||
"gpus": 4,
|
||||
"nodes": 1,
|
||||
"bs": 8,
|
||||
"steps": 1100,
|
||||
"layers": 2,
|
||||
"hidden_size": 256,
|
||||
"seq_length": 256,
|
||||
"heads": 16,
|
||||
"deepspeed": True,
|
||||
"tag": "ds_zero2",
|
||||
"zero": True,
|
||||
"other_args": "",
|
||||
"checkpoint_name": "ckpt_mp4_gpu16_w_zero2",
|
||||
"checkpoint_interval": 1000,
|
||||
"json": "ds_config_func_bs8_zero2.json",
|
||||
}
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
@ -184,7 +206,8 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||
|
||||
def checkpoint_suite():
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero'))
|
||||
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero1'))
|
||||
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero2'))
|
||||
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_without_zero'))
|
||||
|
||||
return suite
|
||||
|
|
|
@ -43,7 +43,7 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
def tearDown(self):
|
||||
os.chdir(self.save_dir)
|
||||
|
||||
def test_mp1_gpu1_node1(self):
|
||||
def test_mp1_gpu1_node1_zero1(self):
|
||||
test_config = {
|
||||
"mp": 1,
|
||||
"gpus": 1,
|
||||
|
@ -55,13 +55,13 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs4.json",
|
||||
"json": "ds_config_func_bs4_zero1.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp1_gpu2_node1(self):
|
||||
def test_mp1_gpu2_node1_zero1(self):
|
||||
test_config = {
|
||||
"mp": 1,
|
||||
"gpus": 2,
|
||||
|
@ -73,13 +73,13 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs8.json",
|
||||
"json": "ds_config_func_bs8_zero1.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp2_gpu4_node1(self):
|
||||
def test_mp2_gpu4_node1_zero1(self):
|
||||
test_config = {
|
||||
"mp": 2,
|
||||
"gpus": 4,
|
||||
|
@ -91,16 +91,13 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs8.json",
|
||||
"json": "ds_config_func_bs8_zero1.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
succ = self.run_partition_activations_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp4_gpu4_node1(self):
|
||||
def test_mp4_gpu4_node1_zero1(self):
|
||||
test_config = {
|
||||
"mp": 4,
|
||||
"gpus": 4,
|
||||
|
@ -112,7 +109,82 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs8.json",
|
||||
"json": "ds_config_func_bs8_zero1.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp1_gpu1_node1_zero2(self):
|
||||
test_config = {
|
||||
"mp": 1,
|
||||
"gpus": 1,
|
||||
"nodes": 1,
|
||||
"bs": 4,
|
||||
"steps": 1000,
|
||||
"layers": 12,
|
||||
"hidden_size": 768,
|
||||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs4_zero2.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp1_gpu2_node1_zero2(self):
|
||||
test_config = {
|
||||
"mp": 1,
|
||||
"gpus": 2,
|
||||
"nodes": 1,
|
||||
"bs": 8,
|
||||
"steps": 1000,
|
||||
"layers": 12,
|
||||
"hidden_size": 768,
|
||||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs8_zero2.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp2_gpu4_node1_zero2(self):
|
||||
test_config = {
|
||||
"mp": 2,
|
||||
"gpus": 4,
|
||||
"nodes": 1,
|
||||
"bs": 8,
|
||||
"steps": 1000,
|
||||
"layers": 12,
|
||||
"hidden_size": 768,
|
||||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs8_zero2.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
succ = self.run_partition_activations_test(test_config, 0.01)
|
||||
self.assertTrue(succ)
|
||||
|
||||
def test_mp4_gpu4_node1_zero2(self):
|
||||
test_config = {
|
||||
"mp": 4,
|
||||
"gpus": 4,
|
||||
"nodes": 1,
|
||||
"bs": 8,
|
||||
"steps": 1000,
|
||||
"layers": 12,
|
||||
"hidden_size": 768,
|
||||
"seq_length": 256,
|
||||
"heads": 12,
|
||||
"deepspeed": False,
|
||||
"json": "ds_config_func_bs8_zero2.json",
|
||||
}
|
||||
|
||||
succ = self.run_test(test_config, 0.01)
|
||||
|
@ -144,11 +216,12 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
print("\n")
|
||||
print("{0}: starting......".format(self.id()))
|
||||
|
||||
baseline_prefix = "gpt2_func_"
|
||||
prefix = "gpt2_partition_activation_"
|
||||
|
||||
# baseline run...
|
||||
test_config["deepspeed"] = False
|
||||
base_file = self.gen_output_name(test_config, prefix)
|
||||
base_file = self.gen_output_name(test_config, baseline_prefix)
|
||||
|
||||
# skip baseline run if it exists.
|
||||
if not self.has_loss_data(base_file):
|
||||
|
@ -159,7 +232,7 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
|
||||
# DeepSpeed run...
|
||||
test_config["deepspeed"] = True
|
||||
test_config["other_args"] = "--partition-activations"
|
||||
test_config["other_args"] = "--deepspeed-activation-checkpointing"
|
||||
print("{0}: DeepSpeed run.".format(self.id()))
|
||||
test_file = self.gen_output_name(test_config, prefix)
|
||||
self.run_gpt2_test(test_config, test_file)
|
||||
|
@ -217,10 +290,16 @@ class GPT2FuncTestCase(BaseTestCase):
|
|||
|
||||
def suite():
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero1'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero1'))
|
||||
|
||||
suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2'))
|
||||
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2'))
|
||||
|
||||
suite.addTest(GPT2FuncTestCase('test_optimizer_scheduler'))
|
||||
return suite
|
||||
|
||||
|
|
|
@ -29,14 +29,16 @@ def pytest_hack(runner_result):
|
|||
assert runner_result.wasSuccessful() # fail the test
|
||||
|
||||
|
||||
def test_run():
|
||||
#def test_megatron():
|
||||
# runner = unittest.TextTestRunner(failfast=True)
|
||||
# pytest_hack(runner.run(Megatron_GPT2.suite()))
|
||||
#
|
||||
#
|
||||
#def test_megatron_checkpoint():
|
||||
# runner = unittest.TextTestRunner(failfast=True)
|
||||
# pytest_hack(runner.run(Megatron_GPT2.checkpoint_suite()))
|
||||
|
||||
|
||||
def test_squad():
|
||||
runner = unittest.TextTestRunner(failfast=True)
|
||||
|
||||
# Add test suites here.
|
||||
pytest_hack(runner.run(Megatron_GPT2.suite()))
|
||||
pytest_hack(runner.run(Megatron_GPT2.checkpoint_suite()))
|
||||
pytest_hack(runner.run(BingBertSquad.suite()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_run()
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import deepspeed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self, hidden_dim, empty_grad=False):
|
||||
super(SimpleModel, self).__init__()
|
||||
self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
|
||||
if empty_grad:
|
||||
self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
|
||||
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, x, y):
|
||||
hidden_dim = x
|
||||
hidden_dim = self.linear(hidden_dim)
|
||||
return self.cross_entropy_loss(hidden_dim, y)
|
||||
|
||||
|
||||
def create_config_from_dict(tmpdir, config_dict):
|
||||
config_path = os.path.join(tmpdir, 'temp_config.json')
|
||||
with open(config_path, 'w') as fd:
|
||||
json.dump(config_dict, fd)
|
||||
return config_path
|
||||
|
||||
|
||||
def get_data_loader(model, total_samples, hidden_dim, device):
|
||||
batch_size = model.train_micro_batch_size_per_gpu()
|
||||
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
|
||||
train_label = torch.empty(total_samples,
|
||||
dtype=torch.long,
|
||||
device=device).random_(hidden_dim)
|
||||
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
|
||||
sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler)
|
||||
return train_loader
|
||||
|
||||
|
||||
def get_args(tmpdir, config_dict):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument('--zero', type=int, default=0)
|
||||
args = parser.parse_args() #args=''
|
||||
|
||||
config_dict["zero_optimization"]["stage"] = args.zero
|
||||
print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
|
||||
config_path = create_config_from_dict(tmpdir, config_dict)
|
||||
|
||||
args.deepspeed_config = config_path
|
||||
return args
|
||||
|
||||
|
||||
def print0(msg):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
rank = int(os.environ['RANK'])
|
||||
print('seed:', 2222 + rank)
|
||||
torch.random.manual_seed(2222 + rank)
|
||||
|
||||
config_dict = {
|
||||
"train_batch_size": 8,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 15
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 0,
|
||||
"reduce_bucket_size": 20
|
||||
}
|
||||
}
|
||||
# "initial_scale_power": 15
|
||||
args = get_args('/tmp/', config_dict)
|
||||
hidden_dim = 4
|
||||
|
||||
model = SimpleModel(hidden_dim, empty_grad=False)
|
||||
|
||||
model, _, _,_ = deepspeed.initialize(args=args,
|
||||
model=model,
|
||||
model_parameters=model.parameters(),
|
||||
dist_init_required=True)
|
||||
|
||||
|
||||
def print_params(tag, model):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
for n, p in model.named_parameters():
|
||||
print0("{} {}:{}".format(tag, n, p))
|
||||
|
||||
|
||||
data_loader = get_data_loader(model=model,
|
||||
total_samples=1000,
|
||||
hidden_dim=hidden_dim,
|
||||
device=model.device)
|
||||
#print_params('pre-train', model)
|
||||
for n, batch in enumerate(data_loader):
|
||||
loss = model(batch[0], batch[1])
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print("LOSS:", loss.item())
|
||||
model.backward(loss)
|
||||
model.step()
|
||||
#print_params('step={}'.format(n), model)
|
||||
if n == 5: break
|
|
@ -8,7 +8,7 @@ from torch.multiprocessing import Process
|
|||
import pytest
|
||||
|
||||
# Worker timeout *after* the first worker has completed.
|
||||
DEEPSPEED_UNIT_WORKER_TIMEOUT = 10
|
||||
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
|
||||
|
||||
|
||||
def distributed_test(world_size=2, backend='nccl'):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import deepspeed
|
||||
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
|
||||
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
|
||||
|
||||
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
|
||||
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
|
||||
|
@ -9,6 +10,7 @@ import argparse
|
|||
import pytest
|
||||
import json
|
||||
import os
|
||||
import numbers
|
||||
from common import distributed_test
|
||||
from simple_model import SimpleModel, random_dataloader, args_from_dict
|
||||
|
||||
|
@ -22,21 +24,6 @@ def compare_deepspeed_states(saved_model, loaded_model):
|
|||
assert saved_model.global_steps == loaded_model.global_steps
|
||||
|
||||
|
||||
def compare_lr_scheduler_states(saved_model, loaded_model):
|
||||
if saved_model.lr_scheduler is None:
|
||||
assert loaded_model.lr_scheduler is None
|
||||
return
|
||||
|
||||
saved = saved_model.lr_scheduler.state_dict()
|
||||
loaded = loaded_model.lr_scheduler.state_dict()
|
||||
assert sorted(saved.keys()) == sorted(loaded.keys())
|
||||
for key in saved.keys():
|
||||
if isinstance(saved[key], torch.Tensor):
|
||||
assert torch.equal(saved[key], loaded[key])
|
||||
else:
|
||||
assert saved[key] == loaded[key]
|
||||
|
||||
|
||||
def compare_model_states(saved_model, loaded_model):
|
||||
compare_deepspeed_states(saved_model, loaded_model)
|
||||
|
||||
|
@ -47,6 +34,11 @@ def compare_model_states(saved_model, loaded_model):
|
|||
for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
|
||||
assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
|
||||
|
||||
elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
|
||||
for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
|
||||
for p0, p1 in zip(partition0, partition1):
|
||||
assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
|
||||
|
||||
elif isinstance(saved_model.optimizer, FP16_Optimizer):
|
||||
for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
|
||||
assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
|
||||
|
@ -61,10 +53,6 @@ def compare_model_states(saved_model, loaded_model):
|
|||
|
||||
|
||||
def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
|
||||
compare_model_states(saved_model, loaded_model)
|
||||
|
||||
assert hasattr(loaded_model, 'optimizer')
|
||||
|
||||
for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
|
||||
loaded_model.optimizer.optimizer.state.values()):
|
||||
for s0, s1 in zip(state0.values(), state1.values()):
|
||||
|
@ -74,11 +62,35 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
|
|||
assert s0 == s1
|
||||
|
||||
|
||||
def checkpoint_correctness_verification(save_folder,
|
||||
args,
|
||||
def compare_lr_scheduler_states(saved_model, loaded_model):
|
||||
assert hasattr(saved_model, 'lr_scheduler')
|
||||
assert hasattr(loaded_model, 'lr_scheduler')
|
||||
|
||||
saved_scheduler = saved_model.lr_scheduler
|
||||
loaded_scheduler = loaded_model.lr_scheduler
|
||||
|
||||
assert hasattr(saved_scheduler, 'state_dict')
|
||||
assert hasattr(loaded_scheduler, 'state_dict')
|
||||
|
||||
saved_sd = saved_scheduler.state_dict()
|
||||
loaded_sd = loaded_scheduler.state_dict()
|
||||
|
||||
print(f"saved_sd = {saved_sd}")
|
||||
print(f"loaded_sd = {loaded_sd}")
|
||||
|
||||
assert saved_sd.keys() == loaded_sd.keys()
|
||||
|
||||
for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
|
||||
if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
|
||||
assert state0 == state1
|
||||
|
||||
|
||||
def checkpoint_correctness_verification(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
load_optimizer_states=True):
|
||||
tmpdir,
|
||||
load_optimizer_states=False,
|
||||
load_lr_scheduler_states=False):
|
||||
|
||||
ds_model, _, _,_ = deepspeed.initialize(args=args,
|
||||
model=model,
|
||||
|
@ -94,6 +106,7 @@ def checkpoint_correctness_verification(save_folder,
|
|||
|
||||
trained_model = ds_model
|
||||
|
||||
save_folder = os.path.join(tmpdir, 'saved_checkpoint')
|
||||
save_tag = '1'
|
||||
|
||||
trained_model.save_checkpoint(save_folder, save_tag)
|
||||
|
@ -104,14 +117,16 @@ def checkpoint_correctness_verification(save_folder,
|
|||
|
||||
loaded_model.load_checkpoint(save_folder,
|
||||
save_tag,
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
load_optimizer_states=load_optimizer_states,
|
||||
load_lr_scheduler_states=load_lr_scheduler_states)
|
||||
|
||||
compare_lr_scheduler_states(trained_model, loaded_model)
|
||||
compare_model_states(trained_model, loaded_model)
|
||||
|
||||
if load_optimizer_states:
|
||||
compare_optimizer_states(trained_model, loaded_model, hidden_dim)
|
||||
else:
|
||||
compare_model_states(trained_model, loaded_model)
|
||||
|
||||
if load_lr_scheduler_states:
|
||||
compare_lr_scheduler_states(trained_model, loaded_model)
|
||||
|
||||
|
||||
def test_checkpoint_unfused_optimizer(tmpdir):
|
||||
|
@ -156,10 +171,10 @@ def test_checkpoint_unfused_optimizer(tmpdir):
|
|||
model,
|
||||
hidden_dim,
|
||||
load_optimizer_states):
|
||||
checkpoint_correctness_verification(tmpdir,
|
||||
args,
|
||||
checkpoint_correctness_verification(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
tmpdir,
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
|
||||
_test_checkpoint_unfused_optimizer(args=args,
|
||||
|
@ -198,10 +213,10 @@ def test_checkpoint_fused_optimizer(tmpdir):
|
|||
|
||||
@distributed_test(world_size=[2])
|
||||
def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
|
||||
checkpoint_correctness_verification(tmpdir,
|
||||
args,
|
||||
checkpoint_correctness_verification(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
tmpdir,
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
|
||||
_test_checkpoint_fused_optimizer(args=args,
|
||||
|
@ -214,7 +229,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
|
|||
load_optimizer_states=False)
|
||||
|
||||
|
||||
def test_checkpoint_zero_optimizer(tmpdir):
|
||||
@pytest.mark.parametrize("zero_stage", [1, 2])
|
||||
def test_checkpoint_zero_optimizer(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
|
@ -231,7 +247,9 @@ def test_checkpoint_zero_optimizer(tmpdir):
|
|||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": True
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
},
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
|
@ -240,17 +258,165 @@ def test_checkpoint_zero_optimizer(tmpdir):
|
|||
|
||||
@distributed_test(world_size=[2])
|
||||
def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
|
||||
checkpoint_correctness_verification(tmpdir,
|
||||
args,
|
||||
checkpoint_correctness_verification(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
tmpdir,
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
|
||||
_test_checkpoint_zero_optimizer(args=args,
|
||||
model=model,
|
||||
hidden_dim=hidden_dim,
|
||||
load_optimizer_states=True)
|
||||
_test_checkpoint_zero_optimizer(args=args,
|
||||
model=model,
|
||||
hidden_dim=hidden_dim,
|
||||
load_optimizer_states=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("zero_stage", [1, 2])
|
||||
def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"betas": [0.8,
|
||||
0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
},
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
|
||||
model = SimpleModel(hidden_dim, empty_grad=False)
|
||||
|
||||
@distributed_test(world_size=[2])
|
||||
def _test_checkpoint_zero_no_optimizer(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
load_optimizer_states):
|
||||
checkpoint_correctness_verification(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
tmpdir,
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
|
||||
_test_checkpoint_zero_no_optimizer(args=args,
|
||||
model=model,
|
||||
hidden_dim=hidden_dim,
|
||||
load_optimizer_states=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
|
||||
def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015,
|
||||
"betas": [0.8,
|
||||
0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 0.001,
|
||||
"warmup_num_steps": 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
|
||||
model = SimpleModel(hidden_dim, empty_grad=False)
|
||||
|
||||
@distributed_test(world_size=[2])
|
||||
def _test_checkpoint_lr_scheduler(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
load_optimizer_states,
|
||||
load_lr_scheduler_states):
|
||||
checkpoint_correctness_verification(
|
||||
args,
|
||||
model,
|
||||
hidden_dim,
|
||||
tmpdir,
|
||||
load_optimizer_states=load_optimizer_states,
|
||||
load_lr_scheduler_states=load_lr_scheduler_states)
|
||||
|
||||
_test_checkpoint_lr_scheduler(args=args,
|
||||
model=model,
|
||||
hidden_dim=hidden_dim,
|
||||
load_optimizer_states=False,
|
||||
load_lr_scheduler_states=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
|
||||
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 1e-5
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 0.001,
|
||||
"warmup_num_steps": 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
|
||||
model = SimpleModel(hidden_dim, empty_grad=False)
|
||||
|
||||
@distributed_test(world_size=[2])
|
||||
def _test_checkpoint_no_lr_scheduler(args,
|
||||
model,
|
||||
hidden_dim,
|
||||
load_optimizer_states,
|
||||
load_lr_scheduler_states):
|
||||
checkpoint_correctness_verification(
|
||||
args,
|
||||
model,
|
||||
hidden_dim,
|
||||
tmpdir,
|
||||
load_optimizer_states=load_optimizer_states,
|
||||
load_lr_scheduler_states=load_lr_scheduler_states)
|
||||
|
||||
_test_checkpoint_no_lr_scheduler(args=args,
|
||||
model=model,
|
||||
hidden_dim=hidden_dim,
|
||||
load_optimizer_states=False,
|
||||
load_lr_scheduler_states=False)
|
||||
|
|
|
@ -17,3 +17,19 @@ def test_only_required_fields(tmpdir):
|
|||
assert run_cfg.train_batch_size == 64
|
||||
assert run_cfg.train_micro_batch_size_per_gpu == 64
|
||||
assert run_cfg.gradient_accumulation_steps == 1
|
||||
|
||||
|
||||
def test_config_duplicate_key(tmpdir):
|
||||
config_dict = '''
|
||||
{
|
||||
"train_batch_size": 24,
|
||||
"train_batch_size": 24,
|
||||
}
|
||||
'''
|
||||
config_path = os.path.join(tmpdir, 'temp_config.json')
|
||||
|
||||
with open(config_path, 'w') as jf:
|
||||
jf.write("%s" % config_dict)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
run_cfg = ds_config.DeepSpeedConfig(config_path)
|
||||
|
|
|
@ -144,7 +144,8 @@ def test_adamw_fp16_empty_grad(tmpdir):
|
|||
_test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
|
||||
|
||||
|
||||
def test_adam_fp16_onecycle_compatibility(tmpdir):
|
||||
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
|
||||
def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
|
@ -171,59 +172,11 @@ def test_adam_fp16_onecycle_compatibility(tmpdir):
|
|||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": False
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
}
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
|
||||
model = SimpleModel(hidden_dim, empty_grad=True)
|
||||
|
||||
@distributed_test(world_size=[1])
|
||||
def _test_adam_fp16_onecycle_compatibility(args, model, hidden_dim):
|
||||
model, _, _,_ = deepspeed.initialize(args=args,
|
||||
model=model,
|
||||
model_parameters=model.parameters())
|
||||
data_loader = random_dataloader(model=model,
|
||||
total_samples=50,
|
||||
hidden_dim=hidden_dim,
|
||||
device=model.device)
|
||||
for n, batch in enumerate(data_loader):
|
||||
loss = model(batch[0], batch[1])
|
||||
model.backward(loss)
|
||||
model.step()
|
||||
|
||||
_test_adam_fp16_onecycle_compatibility(args=args, model=model, hidden_dim=hidden_dim)
|
||||
|
||||
|
||||
def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "OneCycle",
|
||||
"params": {
|
||||
"cycle_first_step_size": 16000,
|
||||
"cycle_first_stair_count": 8000,
|
||||
"decay_step_size": 16000,
|
||||
"cycle_min_lr": 1e-06,
|
||||
"cycle_max_lr": 3e-05,
|
||||
"decay_lr_rate": 1e-07,
|
||||
"cycle_min_mom": 0.85,
|
||||
"cycle_max_mom": 0.99,
|
||||
"decay_mom_rate": 0.0
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": True
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
|
||||
|
@ -248,7 +201,53 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
|
|||
hidden_dim=hidden_dim)
|
||||
|
||||
|
||||
def test_zero_static_scale(tmpdir):
|
||||
@pytest.mark.parametrize("zero_stage", [1, 2])
|
||||
def test_zero_static_scale(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"steps_per_print": 1,
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"loss_scale": 138.
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
}
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
|
||||
@distributed_test(world_size=2)
|
||||
def _test_zero_static_scale(args):
|
||||
hidden_dim = 10
|
||||
model = SimpleModel(hidden_dim, empty_grad=True)
|
||||
model, optim, _,_ = deepspeed.initialize(args=args,
|
||||
model=model,
|
||||
model_parameters=model.parameters())
|
||||
|
||||
# Ensure the static scaler is configured.
|
||||
assert optim.dynamic_loss_scale == False
|
||||
assert optim.loss_scaler.loss_scale == 138.
|
||||
|
||||
# Now make sure things work..
|
||||
data_loader = random_dataloader(model=model,
|
||||
total_samples=10,
|
||||
hidden_dim=hidden_dim,
|
||||
device=model.device)
|
||||
for n, batch in enumerate(data_loader):
|
||||
loss = model(batch[0], batch[1])
|
||||
model.backward(loss)
|
||||
model.step()
|
||||
|
||||
_test_zero_static_scale(args)
|
||||
|
||||
|
||||
def test_zero_static_scale_deprecated_format(tmpdir):
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"steps_per_print": 1,
|
||||
|
@ -291,14 +290,17 @@ def test_zero_static_scale(tmpdir):
|
|||
_test_zero_static_scale(args)
|
||||
|
||||
|
||||
def test_zero_allow_untested_optimizer(tmpdir):
|
||||
@pytest.mark.parametrize("zero_stage", [1, 2])
|
||||
def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"steps_per_print": 1,
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
},
|
||||
"zero_optimization": True,
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
},
|
||||
"zero_allow_untested_optimizer": False
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
|
@ -317,31 +319,34 @@ def test_zero_allow_untested_optimizer(tmpdir):
|
|||
_test_zero_allow_untested_optimizer(args)
|
||||
|
||||
|
||||
def test_zero_empty_partition(tmpdir):
|
||||
config_dict = {
|
||||
"train_batch_size": 3,
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "Adam",
|
||||
"params": {
|
||||
"lr": 0.00015
|
||||
}
|
||||
},
|
||||
"zero_optimization": True
|
||||
}
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
# @pytest.mark.parametrize("zero_stage", [1])
|
||||
# def test_zero_empty_partition(tmpdir, zero_stage):
|
||||
# config_dict = {
|
||||
# "train_batch_size": 3,
|
||||
# "fp16": {
|
||||
# "enabled": True
|
||||
# },
|
||||
# "optimizer": {
|
||||
# "type": "Adam",
|
||||
# "params": {
|
||||
# "lr": 0.00015
|
||||
# }
|
||||
# },
|
||||
# "zero_optimization": {
|
||||
# "stage": zero_stage
|
||||
# }
|
||||
# }
|
||||
# args = args_from_dict(tmpdir, config_dict)
|
||||
|
||||
@distributed_test(world_size=[3])
|
||||
def _test_zero_empty_partition(args):
|
||||
hidden_dim = 1
|
||||
model = SimpleModel(hidden_dim)
|
||||
# Ensure model has 2 parameters, to cause empty partition with DP=3
|
||||
assert len(list(model.parameters())) == 2
|
||||
model, _, _, _ = deepspeed.initialize(args=args,
|
||||
model=model,
|
||||
model_parameters=model.parameters())
|
||||
model.step()
|
||||
# @distributed_test(world_size=[3])
|
||||
# def _test_zero_empty_partition(args):
|
||||
# hidden_dim = 1
|
||||
# model = SimpleModel(hidden_dim)
|
||||
# # Ensure model has 2 parameters, to cause empty partition with DP=3
|
||||
# assert len(list(model.parameters())) == 2
|
||||
# model, _, _, _ = deepspeed.initialize(args=args,
|
||||
# model=model,
|
||||
# model_parameters=model.parameters())
|
||||
# model.step()
|
||||
|
||||
_test_zero_empty_partition(args)
|
||||
# _test_zero_empty_partition(args)
|
||||
|
|
|
@ -73,7 +73,7 @@ def test_two_output_model(tmpdir):
|
|||
|
||||
summed_loss = sum(loss_tuple)
|
||||
scaled_loss = model.backward(summed_loss)
|
||||
expected_scaled_loss = summed_loss / gradient_accumulation_steps
|
||||
expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
|
||||
assert scaled_loss.item() == approx(expected_scaled_loss.item())
|
||||
|
||||
model.step()
|
||||
|
@ -131,7 +131,7 @@ def test_three_output_model(tmpdir):
|
|||
|
||||
summed_loss = sum(loss_tuple)
|
||||
scaled_loss = model.backward(summed_loss)
|
||||
expected_scaled_loss = summed_loss / gradient_accumulation_steps
|
||||
expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
|
||||
assert scaled_loss.item() == approx(expected_scaled_loss.item())
|
||||
|
||||
model.step()
|
||||
|
|
Загрузка…
Ссылка в новой задаче