Updates for ZeRO stage 2 + ZeRO stage 1 with reduce-scatter

Co-authored-by: Tunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: Elton Zheng <eltonz@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: yuxionghe <yuxhe@microsoft.com>
Co-authored-by: Arash Ashari <arashari@microsoft.com>
Jeff Rasley 2020-05-19 01:00:53 -07:00, committed by GitHub
Parent c61e23b4b1
Commit f2ac7eafd5
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
65 changed files: 4703 additions and 1083 deletions

.gitignore (vendored)

@@ -2,6 +2,7 @@
.idea/
*~
*.swp
*.log
deepspeed/git_version_info.py
# Build + installation data

@@ -1 +1 @@
Subproject commit 9e2c735f5aabe48395c03a276fa7a0c51f6d3025
Subproject commit 274787a189b265814ed75dd5ddeae2dce026ea88

README.md

@@ -1,23 +1,30 @@
[![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization library that makes distributed training easy,
efficient, and effective.
[DeepSpeed](https://www.deepspeed.ai/) is a deep learning optimization
library that makes distributed training easy, efficient, and effective.
<p align="center"><i><b>10x Larger Models</b></i></p>
<p align="center"><i><b>5x Faster Training</b></i></p>
<p align="center"><i><b>10x Faster Training</b></i></p>
<p align="center"><i><b>Minimal Code Change</b></i></p>
DeepSpeed can train deep learning models with over a hundred billion parameters on the current
generation of GPU clusters, while achieving over 5x in system performance
generation of GPU clusters, while achieving over 10x in system performance
compared to the state of the art. Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
**_For further documentation, tutorials, and technical deep-dives please see [deepspeed.ai](https://www.deepspeed.ai/)!_**
# News
* [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
* [2020/05/19] [ZeRO-2 empowers training models as large as 170 billion parameters up to 10x faster compared to state-of-the-art](https://www.deepspeed.ai/news/2020/05/19/zero-stage2.html)
<span style="color:dodgerblue">**[_NEW_]**</span>
* [2020/05/19] [DeepSpeed optimizes transformer kernels to achieve the world's fastest BERT training record: 44 minutes on 1024 NVIDIA V100 GPUs](https://www.deepspeed.ai/news/2020/05/19/bert-record.html)
<span style="color:dodgerblue">**[_NEW_]**</span>
* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
# Table of Contents
@@ -39,93 +46,6 @@ a large model easily runs out of memory with pure data parallelism and it is
difficult to use model parallelism. DeepSpeed addresses these challenges to
accelerate model development *and* training.
## Distributed, Effective, and Efficient Training with Ease
The DeepSpeed API is a lightweight wrapper on [PyTorch](https://pytorch.org/). This
means that you can use everything you love in PyTorch without learning a new
platform. In addition, DeepSpeed manages all of the boilerplate state-of-the-art
training techniques, such as distributed training, mixed precision, gradient
accumulation, and checkpointing, so that you can focus on your model development. Most
importantly, you can leverage the distinctive efficiency and effectiveness benefits of
DeepSpeed to boost speed and scale with just a few lines of code changed in your PyTorch
models.
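Below is a minimal sketch of that integration; `args` (an argparse namespace carrying `--local_rank` and `--deepspeed_config`) and `trainset` (a torch dataset) are assumed to exist in the surrounding script.

```python
import torch
import deepspeed

model = torch.nn.Linear(784, 10)  # stand-in for a real model
model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters(),
    training_data=trainset)

for inputs, labels in trainloader:
    outputs = model_engine(inputs.cuda())
    loss = torch.nn.functional.cross_entropy(outputs, labels.cuda())
    model_engine.backward(loss)  # replaces loss.backward()
    model_engine.step()          # replaces optimizer.step()
```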
## Speed
DeepSpeed achieves high performance and fast convergence through a combination of
efficiency optimizations on compute/communication/memory/IO and effectiveness
optimizations on advanced hyperparameter tuning and optimizers. For example:
* DeepSpeed trains BERT-large to parity in 14 hours using 64 GPUs (4 DGX-2 boxes) and in
3.7 hours using 256 GPUs (16 DGX-2 boxes).
**BERT-large Training Times**
| Devices | Source | Training Time (hours) |
| ------------- | --------- | ---------------------:|
| 64 TPUs | Google | 96 |
| 64 V100 GPUs | DeepSpeed | **14** |
| 256 V100 GPUs | NVIDIA | 3.9 |
| 256 V100 GPUs | DeepSpeed | **3.7** |
*Read more*: [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/)
* DeepSpeed trains GPT2 (1.5 billion parameters) 3.75x faster than the state of the art,
NVIDIA Megatron, on Azure GPUs.
*Read more*: [GPT tutorial](https://www.deepspeed.ai/tutorials/megatron/)
## Memory efficiency
DeepSpeed provides memory-efficient data parallelism and enables training models without
model parallelism. For example, DeepSpeed can train models with up to 6 billion parameters on
NVIDIA V100 GPUs with 32GB of device memory. In comparison, existing frameworks (e.g.,
PyTorch's Distributed Data Parallel) run out of memory with 1.5 billion parameter models.
DeepSpeed reduces the training memory footprint through a novel solution called Zero
Redundancy Optimizer (ZeRO). Unlike basic data parallelism where memory states are
replicated across data-parallel processes, ZeRO partitions model states to save
significant memory. The current implementation (stage 1 of ZeRO) reduces memory by up to
4x relative to the state of the art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054).
With this impressive memory reduction, early adopters of DeepSpeed have already
produced a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
## Scalability
DeepSpeed supports efficient data parallelism, model parallelism, and their
combination. ZeRO boosts the scaling capability and efficiency further.
* DeepSpeed provides system support to run models up to 100 billion parameters,
10x larger than the state of the art (8 billion NVIDIA GPT, 11 billion Google T5).
* DeepSpeed can run large models more efficiently, up to 6x faster for models of
various sizes spanning 1.5B to 100B parameters. More specifically, the data parallelism powered by ZeRO
is complementary and can be combined with different types of model parallelism. This allows
DeepSpeed to fit models using a lower degree of model parallelism and a higher batch size, offering
significant performance gains compared to using model parallelism alone.
*Read more*: [technical report](https://arxiv.org/abs/1910.02054)
and [GPT tutorial](https://www.deepspeed.ai/tutorials/megatron/)
![DeepSpeed-vs-Megatron](./docs/assets/images/DeepSpeed-vs-Megatron.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of NVIDIA Megatron-LM) over using Megatron-LM alone.</em>
</p>
## Fast convergence for effectiveness
DeepSpeed supports advanced hyperparameter tuning and large batch size
optimizers such as [LAMB](https://arxiv.org/abs/1904.00962). These improve the
effectiveness of model training and reduce the number of samples required to
converge to the desired accuracy.
*Read more*: [Tuning tutorial](https://www.deepspeed.ai/tutorials/1Cycle/) and [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/)
## Usability
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to six billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.3 billion parameters. In addition, DeepSpeed conveniently supports a flexible combination of ZeRO-powered data parallelism with custom model parallelism, such as the tensor slicing of NVIDIA's Megatron-LM. A sample configuration is sketched below.
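The keys below follow this release's `deepspeed_constants.py`; the concrete values are illustrative only, shown as a Python dict rather than the usual `ds_config.json` file.

```python
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},    # ZeRO requires fp16 in this release
    "zero_optimization": 1,       # 0 = disabled, 1 = optimizer states, 2 = + gradients
    "zero_reduce_scatter": True   # stage 1 only supports reduce-scatter mode
}
```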
# Features
Below we provide a brief feature list, see our detailed [feature


@@ -35,11 +35,6 @@ jobs:
pre-commit run --all-files
displayName: 'Formatting checks'
- script: |
pip install --user pylint
pylint --exit-zero deepspeed/
displayName: 'Code linter'
- script: |
pytest --forked --verbose tests/unit/
displayName: 'Unit tests'


@@ -6,6 +6,8 @@ from deepspeed.pt.deepspeed_light import DeepSpeedLight
from deepspeed.pt.deepspeed_light import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from deepspeed.pt.deepspeed_lr_schedules import add_tuning_arguments
import deepspeed.pt.deepspeed_checkpointing as checkpointing
try:
from deepspeed.git_version_info import git_hash, git_branch
except ImportError:
@@ -14,7 +16,7 @@ except ImportError:
# Export version information
__version_major__ = 0
__version_minor__ = 1
__version_minor__ = 2
__version_patch__ = 0
__version__ = '.'.join(
map(str,
@@ -33,7 +35,8 @@ def initialize(args,
lr_scheduler=None,
mpu=None,
dist_init_required=None,
collate_fn=None):
collate_fn=None,
config_params=None):
"""Initialize the DeepSpeed Engine.
Arguments:
@@ -91,7 +94,8 @@ def initialize(args,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn)
collate_fn=collate_fn,
config_params=config_params)
return_items = [
engine,
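A sketch of the new `config_params` path added above: the DeepSpeed config can now be passed as an in-memory dict instead of pointing `--deepspeed_config` at a JSON file. `args` and `net` are assumed to come from the surrounding training script.

```python
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=net,
    model_parameters=net.parameters(),
    config_params={"train_batch_size": 8, "fp16": {"enabled": True}})
```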


@@ -0,0 +1,724 @@
'''
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
Used to partition the activations stored for backward propagation,
thereby reducing memory consumption.
Also implements CPU checkpointing and contiguous memory checkpointing,
which reduce memory consumption and memory fragmentation.
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
'''
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch.distributed as dist
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
#DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
#MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None
#Model Parameters
num_layers = None
#Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
contiguous_size_buffers = []
size_offsets = []
timers = None
#optimization flags
PARTITION_ACTIVATIONS = False
PA_TO_CPU = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False
def see_memory_usage(message, force=False):
#return
if not force:
return
#dist.barrier()
if dist.get_rank() == 0:
print(message)
print("Memory Allocated ",
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes")
print("Max Memory Allocated ",
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes")
print("Cache Allocated ",
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
"GigaBytes")
print("Max cache Allocated ",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
"GigaBytes")
print(" ")
#input("Press Any Key To Continue ..")
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
transport_stream = None
cuda_device = None
def detach_variable(inputs, device=None):
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
requires_grad = inp.requires_grad
if device is not None:
x = inp.to(device=device)
else:
x = inp
x = x.detach()
x.requires_grad = requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(
"Only tuple of tensors is supported. Got Unsupported input type: ",
type(inputs).__name__)
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for 4+ GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping and ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
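An illustrative use of the tracker (not part of this file), assuming `model_parallel_cuda_manual_seed` has already been called and `x` is an existing CUDA tensor:

```python
tracker = get_cuda_rng_tracker()
with tracker.fork():  # forks 'model-parallel-rng' by default
    y = torch.nn.functional.dropout(x, p=0.1)
```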
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is replacement for that
function.
Two sets of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-model-parallel regions.
model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
global mpu
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + mpu.get_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
if torch.distributed.get_rank() == 0:
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(),
mpu.get_model_parallel_rank(),
mpu.get_data_parallel_rank(),
model_parallel_seed,
data_parallel_seed),
flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
def get_partition_start(item):
global mp_rank, mp_size, mp_group
size = item.numel()
partition_size = size / mp_size
start = partition_size * mp_rank
return int(start)
def get_partition_size(item):
global mp_rank, mp_size, mp_group
size = item.numel()
assert size % mp_size == 0, "activation partitioning requires the tensor size to be divisible by mp_size"
partition_size = size / mp_size
return int(partition_size)
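A worked example of the partitioning arithmetic above (illustrative, not part of this module), for a hypothetical mp_size of 4 and mp_rank of 1:

```python
import torch
item = torch.arange(8.0)               # 8 elements / mp_size 4 -> 2 per rank
part = item.view(-1).narrow(0, 2, 2)   # rank 1 starts at 2 * 1 = 2 -> tensor([2., 3.])
```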
def get_full_inputs(tensors, device=None):
inputs = []
num_args = int(len(tensors) / 2)
for i in range(num_args - 1):
item = tensors[2 * i]
size = tensors[2 * i + 1]
partition_size = item.numel()
tensor_size = partition_size * mp_size
if device is not None:
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
else:
flat_tensor = torch.zeros([tensor_size],
dtype=item.dtype,
device=item.device)
partitions = []
for j in range(mp_size):
part = flat_tensor.narrow(0, partition_size * j, partition_size)
if j == mp_rank:
part.copy_(item)
partitions.append(part)
if mp_group is not None:
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
input_tensor = flat_tensor.view(list(size.numpy()))
item.data = input_tensor.data
inputs.append(item)
inputs.append(tensors[-2])
return tuple(inputs)
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
3) Performance activation partitioning, contiguous memory optimization
4) CPU Checkpointing
5) Profile forward and backward functions
"""
@staticmethod
def forward(ctx, run_function, *args):
global mpu, timers, SYNCHRONIZE, PROFILE_TIME
if SYNCHRONIZE:
torch.cuda.synchronize()
if timers is None and PROFILE_TIME:
timers = Timers()
if PROFILE_TIME:
timers('forward').start()
ctx.run_function = run_function
global num_layers
global mp_rank, mp_size, mp_group
global contiguous_data_buffers, contiguous_size_buffers
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
mp_group = None
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
if cuda_device is None:
see_memory_usage("First Forward Begining", force=True)
if dist.get_rank() == 0:
print(f"Activation Checkpointing Information")
print(
f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
)
print(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
)
print(f"----Synchronization {SYNCHRONIZE}")
print(f"----Profiling {PROFILE_TIME}")
cuda_device = torch.cuda.current_device()
transport_stream = torch.cuda.Stream(device=cuda_device)
if PARTITION_ACTIVATIONS:
#inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
#inputs.append(args[-1])
inputs = []
for i, item in enumerate(args[:-1]):
partition_size = get_partition_size(item)
partition = item.detach().contiguous().view(-1).narrow(
0,
get_partition_start(item),
partition_size).clone()
if CONTIGUOUS_CHECKPOINTING:
buffer_device = torch.device(
'cpu') if PA_TO_CPU else partition.device
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0
contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contiguous_partition)
else:
partition = partition.cpu() if PA_TO_CPU else partition
inputs.append(partition)
inputs.append(args[-1])
#just in case something funky is happening such as reuse of inputs
inputs_cuda = [item.to(cuda_device) for item in args]
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
#ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*inputs_cuda)
del inputs_cuda
#with torch.cuda.stream(transport_stream):
#if PARTITION_ACTIVATIONS:
# new_args = []
# for arg, inp in zip(args,inputs):
# size= torch.tensor(arg.size())
# arg.data = inp.data
# new_args.append(arg)
# new_args.append(size)
# ctx.save_for_backward(*new_args)
if PARTITION_ACTIVATIONS:
new_args = []
for i, (arg, inp) in enumerate(zip(args, inputs)):
size = torch.tensor(arg.size())
arg.data = inp.data
new_args.append(arg)
if CONTIGUOUS_CHECKPOINTING:
numel = size.numel()
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device))
size_offsets.append(0)
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device)
size_offsets[i] = 0
contiguous_size = contiguous_size_buffers[i].narrow(
0,
size_offsets[i],
numel).data.copy_(size.data)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contiguous_size)
else:
new_args.append(size)
#if dist.get_rank() == 0:
# print (f"The stored tensor is {contiguous_size} and orginal one is {size} ")
ctx.save_for_backward(*new_args)
else:
ctx.save_for_backward(*args)
if PROFILE_TIME:
timers('forward').stop()
timers.log(['forward'])
if SYNCHRONIZE:
torch.cuda.synchronize()
return outputs
@staticmethod
def backward(ctx, *args):
global timers
#see_memory_usage("In backward", force=True)
#removing pointers to the contiguous buffer memory
#so that they can be garbage collected once the checkpoints
#have been used
if SYNCHRONIZE:
torch.cuda.synchronize()
if PROFILE_TIME:
timers('backward').start()
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
#see_memory_usage("In backward checkpointing code", force=True)
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if PARTITION_ACTIVATIONS:
#with torch.cuda.stream(transport_stream):
inputs = get_full_inputs(ctx.saved_tensors,
device=cuda_device if PA_TO_CPU else None)
detached_inputs = detach_variable(inputs)
else:
inputs = ctx.saved_tensors
detached_inputs = detach_variable(inputs)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# if PARTITION_ACTIVATIONS:
# current_stream=torch.cuda.current_stream()
# current_stream.wait_stream(transport_stream)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs, )
torch.autograd.backward(outputs, args)
if PROFILE_TIME:
timers('backward').stop()
timers.log(['backward'])
if SYNCHRONIZE:
torch.cuda.synchronize()
return (None, ) + tuple(inp.grad for inp in detached_inputs)
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """
return CheckpointFunction.apply(function, *args)
def partition_activations_in_checkpoint(partition_activation):
global PARTITION_ACTIVATIONS
PARTITION_ACTIVATIONS = partition_activation
if dist.get_rank() == 0:
print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def set_num_layers(nlayers):
global num_layers
num_layers = nlayers
def reset():
"""Resets memory buffers related to contiguous memory optimizations.
Should be called during eval when multiple forward propagations are
computed without any backward propagation that usually clears these
buffers.
Arguments:
None
Return:
None
"""
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
def _configure_using_config_file(deepspeed_config):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config
print(config.repr())
PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
num_layers = config.number_checkpoints
PA_TO_CPU = config.cpu_checkpointing
SYNCHRONIZE = config.synchronize_checkpoint_boundary
PROFILE_TIME = config.profile
def _configure_defaults():
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
PARTITION_ACTIVATIONS = False
CONTIGUOUS_CHECKPOINTING = False
num_layers = None
PA_TO_CPU = False
SYNCHRONIZE = False
PROFILE_TIME = False
deepspeed_checkpointing_enabled = True
def configure(
mpu_,
deepspeed_config=None,
partition_activations=None,
contiguous_checkpointing=None,
num_checkpoints=None,
checkpoint_in_cpu=None,
synchronize=None,
profile=None,
):
"""Configure DeepSpeed Activation Checkpointing.
Arguments:
mpu_: Optional: An object that implements the following methods
get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size
deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to
configure DeepSpeed Activation Checkpointing
partition_activations: Optional: Partitions activation checkpoint across model parallel
GPUs when enabled. By default False. Will overwrite deepspeed_config if provided
contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
provided
num_checkpoints: Optional: Number of activation checkpoints stored during the forward
propagation of the model. Used to calculate the buffer size for contiguous_checkpointing
Will overwrite deepspeed_config if provided
checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
partition_activations. Default is False. Will overwrite deepspeed_config if provided
synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of
each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
By default false. Will overwrite deepspeed_config if provided
profile: Optional: Logs the forward and backward time for each
deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
if provided
Returns:
None
"""
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
_configure_defaults()
if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config)
if mpu_ is not None:
mpu = mpu_
if partition_activations is not None:
PARTITION_ACTIVATIONS = partition_activations
if contiguous_checkpointing is not None:
CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing
if num_checkpoints is not None:
num_layers = num_checkpoints
if checkpoint_in_cpu is not None:
PA_TO_CPU = checkpoint_in_cpu
if synchronize is not None:
SYNCHRONIZE = synchronize
if profile is not None:
PROFILE_TIME = profile
if PA_TO_CPU or CONTIGUOUS_CHECKPOINTING:
assert PARTITION_ACTIVATIONS, "CPU checkpointing/contiguous checkpointing is only available with partitioned activations. Set partition_activations to true in the DeepSpeed config."
if CONTIGUOUS_CHECKPOINTING:
assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
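A sketch of driving this API from a Megatron-style training script; `mpu`, `layer`, `hidden`, and `attention_mask` are assumed to exist on the caller's side:

```python
import deepspeed

deepspeed.checkpointing.configure(
    mpu,
    partition_activations=True,
    contiguous_checkpointing=True,
    num_checkpoints=24,      # required when contiguous checkpointing is enabled
    checkpoint_in_cpu=True)  # only valid together with partition_activations

# Inside the model's forward pass, checkpoint each layer:
hidden = deepspeed.checkpointing.checkpoint(layer.forward, hidden, attention_mask)
```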
def is_configured():
"""True if deepspeed activation checkpointing has been configured
by calling deepspeed.checkpointing.configure, else returns false
Arguments:
None
Return:
True if configured, else False
"""
global deepspeed_checkpointing_enabled
return deepspeed_checkpointing_enabled


@@ -0,0 +1,110 @@
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
#########################################
# DeepSpeed Activation Checkpointing
#########################################
# Activation checkpointing saves memory by keeping only a select few
# activations for the backward propagation.
ACTIVATION_CHKPT_FORMAT = '''
Activation Checkpointing should be configured as:
"session_params": {
"activation_checkpointing": {
"partition_activations": [true|false],
"number_checkpoints": 100,
"contiguous_memory_optimization": [true|false],
"cpu_checkpointing": [true|false],
"profile": [true|false],
"synchronize_checkpoint_boundary": [true|false]
}
}
'''
ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations'
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False
ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints'
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization'
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary'
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False
ACT_CHKPT_PROFILE = 'profile'
ACT_CHKPT_PROFILE_DEFAULT = False
ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing'
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False
ACT_CHKPT = 'activation_checkpointing'
ACT_CHKPT_DEFAULT = {
ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION:
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY:
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT,
ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
}
class DeepSpeedActivationCheckpointingConfig(object):
def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__()
self.partition_activations = None
self.contiguous_memory_optimization = None
self.cpu_checkpointing = None
self.number_checkpoints = None
self.synchronize_checkpoint_boundary = None
self.profile = None
if ACT_CHKPT in param_dict.keys():
act_chkpt_config_dict = param_dict[ACT_CHKPT]
else:
act_chkpt_config_dict = ACT_CHKPT_DEFAULT
self._initialize(act_chkpt_config_dict)
"""
For json serialization
"""
def repr(self):
return self.__dict__
def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_PARTITION_ACTIVATIONS,
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)
self.contiguous_memory_optimization = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)
self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_CPU_CHECKPOINTING,
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)
self.number_checkpoints = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_NUMBER_CHECKPOINTS,
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)
self.profile = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_PROFILE,
ACT_CHKPT_PROFILE_DEFAULT)
self.synchronize_checkpoint_boundary = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)


@@ -8,6 +8,9 @@ import logging
import json
from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig
from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig
TENSOR_CORE_ALIGN_SIZE = 8
ADAM_OPTIMIZER = 'adam'
@@ -15,13 +18,6 @@ LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
def get_scalar_param(param_dict, param_name, param_default_value):
if param_name in param_dict.keys():
return param_dict[param_name]
else:
return param_default_value
def get_fp16_enabled(param_dict):
if FP16 in param_dict.keys():
return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
@@ -92,10 +88,20 @@ def get_sparse_gradients_enabled(param_dict):
return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)
def get_zero_enabled(param_dict):
def get_zero_optimization(param_dict):
return get_scalar_param(param_dict, ZERO_OPTIMIZATION, ZERO_OPTIMIZATION_DEFAULT)
def get_zero_reduce_scatter(param_dict):
return get_scalar_param(param_dict, ZERO_REDUCE_SCATTER, ZERO_REDUCE_SCATTER_DEFAULT)
def get_zero_max_elements_per_comm(param_dict):
return get_scalar_param(param_dict,
ZERO_MAX_ELEMENTS_PER_COMM,
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT)
def get_allgather_size(param_dict):
return get_scalar_param(param_dict,
ALLGATHER_SIZE,
@@ -204,6 +210,10 @@ def get_wall_clock_breakdown(param_dict):
WALL_CLOCK_BREAKDOWN_DEFAULT)
def get_memory_breakdown(param_dict):
return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)
def get_tensorboard_enabled(param_dict):
if TENSORBOARD in param_dict.keys():
return get_scalar_param(param_dict[TENSORBOARD],
@@ -231,10 +241,39 @@ def get_tensorboard_job_name(param_dict):
return TENSORBOARD_JOB_NAME_DEFAULT
'''Write deepspeed config files by modifying basic templates.
Can be used for quickly changing parameters via command line parameters.'''
class DeepSpeedConfigWriter:
def __init__(self, data=None):
self.data = data if data is not None else {}
def add_config(self, key, value):
self.data[key] = value
def load_config(self, filename):
self.data = json.load(open(filename,
'r'),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
def write_config(self, filename):
with open(filename, 'w') as outfile:
json.dump(self.data, outfile)
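A sketch of how this writer might be used; the file names are hypothetical:

```python
writer = DeepSpeedConfigWriter()
writer.load_config('ds_config_base.json')   # hypothetical template
writer.add_config('train_batch_size', 32)   # override one parameter
writer.write_config('ds_config_run.json')
```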
class DeepSpeedConfig(object):
def __init__(self, json_file, mpu=None):
def __init__(self, json_file, mpu=None, param_dict=None):
super(DeepSpeedConfig, self).__init__()
self._param_dict = json.load(open(json_file, 'r'))
if param_dict is None:
self._param_dict = json.load(
open(json_file,
'r'),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
self._param_dict = param_dict
try:
self.global_rank = torch.distributed.get_rank()
if mpu is None:
@@ -263,7 +302,14 @@ class DeepSpeedConfig(object):
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
self.allgather_size = get_allgather_size(param_dict)
self.zero_enabled = get_zero_enabled(param_dict)
self.zero_config = DeepSpeedZeroConfig(param_dict)
self.zero_optimization_stage = self.zero_config.stage
self.zero_enabled = self.zero_optimization_stage > 0
self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(
param_dict)
self.gradient_clipping = get_gradient_clipping(param_dict)
self.fp16_enabled = get_fp16_enabled(param_dict)
self.loss_scale = get_loss_scale(param_dict)
@@ -285,6 +331,7 @@ class DeepSpeedConfig(object):
self.scheduler_params = get_scheduler_params(param_dict)
self.wall_clock_breakdown = get_wall_clock_breakdown(param_dict)
self.memory_breakdown = get_memory_breakdown(param_dict)
self.tensorboard_enabled = get_tensorboard_enabled(param_dict)
self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
@@ -305,8 +352,8 @@ class DeepSpeedConfig(object):
f'Gradient accumulation steps: {grad_acc} has to be greater than 0'
assert train_batch == micro_batch * grad_acc * self.world_size, \
(f'Check batch related parameters. Train_batch_size is not equal'
'to micro_batch_per_gpu * gradient_acc_step * world_size'
(f'Check batch related parameters. train_batch_size is not equal'
' to micro_batch_per_gpu * gradient_acc_step * world_size'
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
def _set_batch_related_parameters(self):
@@ -387,6 +434,7 @@ class DeepSpeedConfig(object):
def _do_error_check(self):
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)


@@ -0,0 +1,25 @@
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
"""
Collection of DeepSpeed configuration utilities
"""
def get_scalar_param(param_dict, param_name, param_default_value):
if param_name in param_dict.keys():
return param_dict[param_name]
else:
return param_default_value
def dict_raise_error_on_duplicate_keys(ordered_pairs):
"""Reject duplicate keys."""
d = {}
for k, v in ordered_pairs:
if k in d:
raise ValueError("Duplicate key in DeepSpeed config: %r" % (k, ))
else:
d[k] = v
return d
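A sketch of the intended use: pass the hook to `json.load`/`json.loads` so duplicate keys fail fast instead of silently keeping the last value.

```python
import json

cfg = json.loads('{"train_batch_size": 4, "train_batch_size": 8}',
                 object_pairs_hook=dict_raise_error_on_duplicate_keys)
# -> ValueError: Duplicate key in DeepSpeed config: 'train_batch_size'
```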


@@ -15,7 +15,7 @@ ROUTE_ENCODE = "encode"
# Batch size
#############################################
TRAIN_BATCH_SIZE = "train_batch_size"
TRAIN_BATCH_SIZE_DEFAULT = 1
TRAIN_BATCH_SIZE_DEFAULT = None
#############################################
# Optimizer and lr scheduler
@@ -133,14 +133,27 @@ GRADIENT_CLIPPING_DEFAULT = 0.
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users can configure in ds_config.json as below example:
# Users have to configure the desired optimization (0 means disabled) in params.json, as in the example below:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"zero_optimization": true,
"zero_all_gather_size": 200
"session_params": {
"zero_optimization": [0|1|2],
"zero_all_gather_size": 200
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DEFAULT = False
ZERO_OPTIMIZATION_DEFAULT = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_REDUCE_SCATTER = "zero_reduce_scatter"
ZERO_REDUCE_SCATTER_DEFAULT = True
ZERO_MAX_ELEMENTS_PER_COMM = "zero_max_elements_per_comm"
ZERO_MAX_ELEMENTS_PER_COMM_DEFAULT = 5e8
ALLGATHER_SIZE = 'allgather_size'
ALLGATHER_SIZE_DEFAULT = 500000000
@@ -217,6 +230,9 @@ Wall clock breakdown should be enabled as:
WALL_CLOCK_BREAKDOWN = 'wall_clock_breakdown'
WALL_CLOCK_BREAKDOWN_DEFAULT = False
MEMORY_BREAKDOWN = 'memory_breakdown'
MEMORY_BREAKDOWN_DEFAULT = False
#########################################
# Tensorboard
#########################################


@@ -8,11 +8,14 @@ import os
import warnings
import torch.distributed as dist
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
@@ -21,8 +24,10 @@ from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import ROUTE_TRAIN, ROUTE_PREDICT, \
ROUTE_EVAL, TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.pt.deepspeed_constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT, \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
@@ -96,7 +101,8 @@ class DeepSpeedLight(Module):
lr_scheduler=None,
mpu=None,
dist_init_required=None,
collate_fn=None):
collate_fn=None,
config_params=None):
super(DeepSpeedLight, self).__init__()
logging.basicConfig(level=logging.INFO,
@@ -116,6 +122,7 @@ class DeepSpeedLight(Module):
self.gradient_predivide_factor = 1.0
self.gradient_average = True
self.warn_unscaled_loss = True
self.config_params = config_params
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
@@ -146,6 +153,9 @@
# Configure distributed model
self._configure_distributed_model(model)
# Configure wall clock timer
self.timers = SynchronizedWallClockTimer()
# Throughput timer
self.tput_timer = ThroughputTimer(
batch_size=self.train_micro_batch_size_per_gpu(),
@@ -163,9 +173,6 @@
self._configure_lr_scheduler(lr_scheduler)
self._report_progress(0)
# Configure wall clock timer
self.timers = SynchronizedWallClockTimer()
# Bookkeeping for csr support
self.csr_tensor_module_names = set()
if self.sparse_gradients_enabled():
@@ -245,6 +252,9 @@
def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown
def memory_breakdown(self):
return self._config.memory_breakdown
def sparse_gradients_enabled(self):
return self._config.sparse_gradients_enabled
@@ -275,6 +285,30 @@
def zero_allow_untested_optimizer(self):
return self._config.zero_allow_untested_optimizer
def zero_reduce_scatter(self):
return self._config.zero_config.reduce_scatter
def zero_overlap_comm(self):
return self._config.zero_config.overlap_comm
def zero_max_elements_per_comm(self):
return self._config.zero_max_elements_per_comm
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size
def zero_allgather_bucket_size(self):
return self._config.zero_config.allgather_bucket_size
def zero_optimization_partition_gradients(self):
return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS
def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients
def allgather_size(self):
return self._config.allgather_size
@@ -296,8 +330,8 @@
def steps_per_print(self):
return self._config.steps_per_print
def disable_allgather(self):
return self._config.disable_allgather
def zero_allgather_partitions(self):
return self._config.zero_config.allgather_partitions
def dump_state(self):
return self._config.dump_state
@@ -375,7 +409,9 @@
# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):
self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
self._config = DeepSpeedConfig(args.deepspeed_config, mpu)
self._config = DeepSpeedConfig(args.deepspeed_config,
mpu,
param_dict=self.config_params)
# Validate command line arguments
def _do_args_sanity_check(self, args):
@@ -390,6 +426,7 @@
assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
'DeepSpeed requires integer command line parameter --local_rank'
if self.config_params is None:
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
'DeepSpeed requires --deepspeed_config to specify configuration file'
@@ -424,7 +461,8 @@
else:
self.data_parallel_group = self.mpu.get_data_parallel_group()
self.dp_world_size = self.mpu.get_data_parallel_world_size()
src_rank = self.mpu.get_model_parallel_rank()
src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
print(f"global src_rank={src_rank}")
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
@@ -518,17 +556,42 @@
return optimizer
def _configure_zero_optimizer(self, optimizer):
logging.info('Creating fp16 zero optimizer')
optimizer = FP16_DeepSpeedZeroOptimizer(
zero_stage = self.zero_optimization_stage()
logging.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
logging.info('Creating fp16 ZeRO Optimizer Stage 1')
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
optimizer,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
dp_process_group=self.data_parallel_group,
clip_grad=self.gradient_clipping(),
all_gather_partitions=not self.disable_allgather(),
allgather_size=self.allgather_size(),
all_gather_partitions=self.zero_allgather_partitions(),
allgather_size=self.zero_allgather_bucket_size(),
max_elements_per_comm=self.zero_reduce_bucket_size(),
dp_process_group=self.data_parallel_group,
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
assert self.gradient_accumulation_steps() == 1, "ZeRO stage 2 does not support gradient accumulation; if you need gradient accumulation, please use ZeRO stage 1"
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=self.timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
mpu=self.mpu)
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
logging.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
return optimizer
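For reference, a sketch of a config that exercises the stage-2 path above; the nested shape follows the ZERO_FORMAT template in deepspeed_zero_config.py, and the values are illustrative only.

```python
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},  # ZeRO requires fp16 in this release
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 500000000,
        "allgather_bucket_size": 500000000
    }
}
```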
@@ -624,6 +687,15 @@
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
if self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter()
self.optimizer.reduce_scatter_gradients(
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor,
gradient_average=self.gradient_average)
elif self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
def backward(self, loss, allreduce_gradients=True):
@@ -636,7 +708,7 @@
# scale loss w.r.t. gradient accumulation if needed
if self.gradient_accumulation_steps() > 1:
loss = self._scale_loss(loss)
loss = self._scale_loss(loss.float())
# Log training Loss
if self.tensorboard_enabled():
@@ -765,16 +837,17 @@
'backward_inner_microstep',
'backward_allreduce_microstep',
'step_microstep'
])
# Log timing
if self.tensorboard_enabled():
],
memory_breakdown=self.memory_breakdown())
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/Samples/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
if self.tensorboard_enabled() and torch.distributed.get_rank(
) == 0: # this is done before the log because log resets timers
self.summary_events = [(f'Train/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
(f'Train/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
@@ -971,19 +1044,30 @@
if not os.path.exists(dirname):
os.makedirs(dirname)
def load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):
r"""Load training checkpoint
Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
load_optimizer_states: Optional. Boolean to load the training optimizer states from checkpoint. Ex. Adam's momentum and variance
load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from checkpoint.
Return:
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
client_state: State dictionary used for loading required training states in the client code.
"""
load_path, client_states = self._load_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
load_path, client_states = self._load_checkpoint(load_dir,
tag,
load_module_strict=load_module_strict,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
if self.zero_optimization() and load_path is not None:
self._load_zero_checkpoint(load_dir,
@@ -992,7 +1076,12 @@
return load_path, client_states
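A sketch of the extended checkpoint API from the client side; `model_engine`, `save_dir`, and `step` are assumed to exist in the training script:

```python
# Save on every rank (ZeRO partitions optimizer state per rank):
model_engine.save_checkpoint(save_dir, step, client_state={'step': step})

# Resume later, tolerating small module changes:
load_path, client_state = model_engine.load_checkpoint(
    save_dir,
    step,
    load_module_strict=False,
    load_optimizer_states=True,
    load_lr_scheduler_states=True)
```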
def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def _load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):
load_path = self._get_ckpt_name(load_dir, tag)
@@ -1005,12 +1094,13 @@
logging.info('Loading checkpoint: {}'.format(load_path))
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
self.load_module_state_dict(checkpoint['module'])
self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
self.optimizer.load_state_dict(checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
if self.lr_scheduler is not None:
if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
@@ -1019,6 +1109,7 @@
deepspeed_states = [
'module',
'optimizer',
'lr_scheduler',
'csr_tensor_module_names',
'skipped_steps',
'global_steps'
@@ -1058,19 +1149,15 @@
# There seems to be an issue creating them in parallel
self._create_checkpoint_files(save_dir, tag)
try:
if self.save_non_zero_checkpoint:
self._save_checkpoint(save_dir, tag, client_state=client_state)
if self.save_zero_checkpoint:
self._save_zero_checkpoint(save_dir, tag)
except:
logging.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
return False
return True
def _create_checkpoint_files(self, save_dir, tag):
#checkpoint files are created sequentially
for rank in range(self.world_size):
if rank == self.global_rank:
@@ -1114,14 +1201,8 @@
torch.save(state, save_path)
def _save_zero_checkpoint(self, save_path, tag):
try:
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
#self._ensure_directory_exists(zero_checkpoint_name)
except:
logging.error(
f'Failed Saving Zero model checkpoint to {save_path} with tag {tag}')
zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
torch.save(zero_sd, zero_checkpoint_name)
logging.info('zero checkpoint saved {}'.format(zero_checkpoint_name))

deepspeed/pt/deepspeed_timer.py (normal file → executable file)

@@ -69,13 +69,27 @@ class SynchronizedWallClockTimer:
self.timers[name] = self.Timer(name)
return self.timers[name]
def log(self, names, normalizer=1.0, reset=True):
@staticmethod
def memory_usage():
alloc = "mem_allocated: {:.4f} GB".format(torch.cuda.memory_allocated() /
(1024 * 1024 * 1024))
max_alloc = "max_mem_allocated: {:.4f} GB".format(
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024))
cache = "cache_allocated: {:.4f} GB".format(torch.cuda.memory_cached() /
(1024 * 1024 * 1024))
max_cache = "max_cache_allocated: {:.4f} GB".format(
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))
return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)
def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if memory_breakdown:
string += self.memory_usage()
print_rank_0(string)
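A sketch of how the new memory_breakdown flag surfaces in timer logging; `model` and `batch` are assumed to exist:

```python
timers = SynchronizedWallClockTimer()
timers('forward').start()
out = model(batch)
timers('forward').stop()
timers.log(['forward'], memory_breakdown=True)
# -> time (ms) | forward: 12.34 | mem_allocated: 1.2345 GB | max_mem_allocated: ...
```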

deepspeed/pt/deepspeed_utils.py (normal file → executable file)

@@ -12,9 +12,10 @@ from torch._six import inf
class CheckOverflow(object):
'''Checks for overflow in gradients across parallel processes'''
def __init__(self, param_groups=None, mpu=None):
def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False):
self.mpu = mpu
self.params = [] if param_groups else None
self.zero_reduce_scatter = zero_reduce_scatter
if param_groups:
for group in param_groups:
for param in group:
@ -54,8 +55,8 @@ class CheckOverflow(object):
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
for i, p in enumerate(params):
if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
return True
return False
@ -67,7 +68,11 @@ class CheckOverflow(object):
#torch.distributed.all_reduce(overflow_gpu,
# op=torch.distributed.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
if self.mpu is not None:
if self.zero_reduce_scatter:
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=torch.distributed.group.WORLD)
elif self.mpu is not None:
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_model_parallel_group())
@ -76,7 +81,7 @@ class CheckOverflow(object):
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x):
def _has_inf_or_nan(x, i):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
@ -93,10 +98,25 @@ class CheckOverflow(object):
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
_handle_overflow(cpu_sum, x, i)
return True
return False
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
print(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
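The cross-rank agreement above reduces to a MAX all-reduce on a one-element flag; a standalone sketch of the idea (assumes `torch.distributed` is initialized and CUDA is available):
```python
import torch
import torch.distributed as dist

local_overflow = True   # e.g. result of _has_inf_or_nan on this rank's grads
overflow_gpu = torch.cuda.ByteTensor([local_overflow])
# With zero_reduce_scatter, reduce over the WORLD group so every rank agrees
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
overflow = bool(overflow_gpu[0].item())  # identical on all ranks; skip the step if True
```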
def get_grad_norm(parameters, norm_type=2, mpu=None):
"""Clips gradient norm of an iterable of parameters.
@ -221,3 +241,33 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
total_norm = -1
return total_norm
def is_model_parallel_parameter(p):
return hasattr(p, 'model_parallel') and p.model_parallel
def see_memory_usage(message):
return  # debugging helper: short-circuited here, so the prints below are disabled
if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
return
# Print message except when distributed but not rank 0
print(message, flush=True)
print("Memory Allocated ",
torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print("Max Memory Allocated ",
torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print("Cache Allocated ",
torch.cuda.memory_cached() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print("Max cache Allocated ",
torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
"GigaBytes",
flush=True)
print(" ", flush=True)


@ -0,0 +1,150 @@
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
import logging
#from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
#########################################
# ZeRO optimization
#########################################
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as in the example below:
ZERO_FORMAT = '''
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
}
}
'''
ZERO_OPTIMIZATION = 'zero_optimization'
ZERO_OPTIMIZATION_DISABLED = 0
ZERO_OPTIMIZATION_OPTIMIZER_STATES = 1
ZERO_OPTIMIZATION_GRADIENTS = 2
ZERO_OPTIMIZATION_WEIGHTS = 3
MAX_STAGE_ZERO_OPTIMIZATION = ZERO_OPTIMIZATION_GRADIENTS
ZERO_OPTIMIZATION_STAGE = 'stage'
ZERO_OPTIMIZATION_STAGE_1 = 'stage_1'
ZERO_OPTIMIZATION_STAGE_2 = 'stage_2'
ZERO_OPTIMIZATION_STAGE_3 = 'stage_3'
ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions'
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter'
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True
ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm'
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS = 'contiguous_gradients'
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT = True
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE = 'reduce_bucket_size'
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT
}
class DeepSpeedZeroConfig(object):
def __init__(self, param_dict):
super(DeepSpeedZeroConfig, self).__init__()
self.stage = None
self.contiguous_gradients = None
self.reduce_scatter = None
self.reduce_bucket_size = None
self.allgather_partitions = None
self.allgather_bucket_size = None
self.overlap_comm = None
if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
if type(zero_config_dict) is bool:
zero_config_dict = self.read_zero_config_deprecated(param_dict)
else:
zero_config_dict = ZERO_OPTIMIZATION_DEFAULT
self._initialize(zero_config_dict)
def read_zero_config_deprecated(self, param_dict):
zero_config_dict = {}
zero_config_dict[
ZERO_OPTIMIZATION_STAGE] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
if zero_config_dict[ZERO_OPTIMIZATION_STAGE] > 0:
zero_config_dict[ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE] = get_scalar_param(
param_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
logging.warning(
'DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}'
.format(ZERO_FORMAT))
return zero_config_dict
"""
For json serialization
"""
def repr(self):
return self.__dict__
def _initialize(self, zero_config_dict):
self.stage = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_STAGE,
ZERO_OPTIMIZATION_STAGE_DEFAULT)
self.contiguous_gradients = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)
self.reduce_bucket_size = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT)
self.reduce_scatter = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_REDUCE_SCATTER,
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)
self.overlap_comm = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_OVERLAP_COMM,
ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)
self.allgather_partitions = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT)
self.allgather_bucket_size = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
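A usage sketch for the class above (the `params` dict stands in for a parsed `deepspeed_config` JSON; values are illustrative):
```python
params = {
    "zero_optimization": {
        "stage": 2,
        "reduce_scatter": True,
        "reduce_bucket_size": 50000000,
    }
}
zero_cfg = DeepSpeedZeroConfig(params)
print(zero_cfg.stage)                  # 2
print(zero_cfg.allgather_bucket_size)  # 500000000 (falls back to the default)
print(zero_cfg.repr())                 # full dict, e.g. for JSON serialization
```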

Diff between files not shown because of its large size.


@ -0,0 +1,803 @@
import math
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from collections import defaultdict
from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups, \
pprint
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
def flatten_dense_tensors_sub_partition_aligned(tensor_list,
dp,
max_elements_per_comm,
pg):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
pprint("Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm))
max_elements_per_comm = min(max_elements_per_comm, num_elements)
sub_partition_size = int(max_elements_per_comm // dp)
alignment = sub_partition_size
# if alignment == 0:
# # number of elements not divisible by dp, outside range and small model must pad with zeroes
# pad_tensor = torch.zeros(max_elements_per_comm,
# device=tensor_list[0].device,
# dtype=tensor_list[0].dtype)
# return _flatten_dense_tensors(pad_tensor)
remaining = int(num_elements % alignment)
# ensure we have equal sized sub-partitions
elements_to_add = 0
if remaining:
elements_to_add = alignment - remaining
# adding padded tensor later after we check comm alignment
pprint("adding pad tensor for alignment, {} + {}->{}".format(
num_elements,
elements_to_add,
num_elements + elements_to_add))
#num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
num_partitions = int((num_elements + elements_to_add) // sub_partition_size)
assert (num_elements + elements_to_add) % sub_partition_size == 0, "num elements should be " \
"aligned by sub partition " \
"size"
num_comm_intervals = int(num_partitions // dp)
partition_remaining = int(num_partitions % dp)
pprint("num_comm_intervals={}, partition_remaining={}".format(
num_comm_intervals,
partition_remaining))
if partition_remaining != 0:
pprint("adding pad tensor and/or extra sub partition")
# add pad tensor for alignment of comm interval, this overrules previous possibly sub-partition alignment
num_comm_intervals += 1
aligned_comm_elements = num_comm_intervals * sub_partition_size * dp
elements_to_add = aligned_comm_elements - num_elements
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
pprint("adding pad tensor and/or extra sub partition, {} + {}->{}".format(
num_elements,
elements_to_add,
num_elements + elements_to_add))
num_elements += elements_to_add
elif elements_to_add > 0:
# add pad tensor for just alignment of sub-partition
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements += elements_to_add
if pg is None or dist.get_rank(group=pg) == 0:
print("Number of Elements (w. padding) is ", num_elements)
padded_num_elems = 0
for p in padded_tensor_list:
padded_num_elems += p.numel()
assert num_elements == padded_num_elems, "{} != {}, rank={}".format(num_elements, padded_num_elems, dist.get_rank())
return _flatten_dense_tensors(padded_tensor_list)
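A worked example of the alignment arithmetic above, with small illustrative numbers:
```python
# dp=10 ranks, max_elements_per_comm=40, a model with 95 elements
dp, max_elements_per_comm, num_elements = 10, 40, 95
sub_partition_size = max_elements_per_comm // dp      # 4
remaining = num_elements % sub_partition_size         # 3, so pad 95 -> 96
num_elements += sub_partition_size - remaining        # 96
num_partitions = num_elements // sub_partition_size   # 24
# 24 % dp != 0, so a whole extra comm interval is added: num_comm_intervals
# becomes 24 // 10 + 1 = 3 and the flat buffer is padded to 3 * 10 * 4 = 120.
```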
def _single_range_check(current_index, start_index, end_index, tensor_size):
offset = 0
if (current_index >= start_index) and (current_index < end_index):
# Fully inside bounds
return True, offset
elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
# Partially contained, compute offset
offset = start_index - current_index
return True, offset
else:
return False, offset
def _range_check(current_index, element_intervals, tensor_size):
results = []
for comm_idx, interval in enumerate(element_intervals):
start_index, end_index = interval
contained, offset = _single_range_check(current_index, start_index, end_index, tensor_size)
if contained:
results.append((contained, offset, comm_idx))
if len(results) == 0:
return [(False, 0, -1)]
return results
class FP16_DeepSpeedZeroOptimizer_Stage1(object):
"""
FP16_DeepSpeedZeroOptimizer_Stage1 is designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
https://arxiv.org/abs/1910.02054
This version aligns with stage-1 in the paper above.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
dp_process_group=None,
partition_size=None,
mpu=None,
all_gather_partitions=True,
allgather_size=500000000,
clip_grad=0.0,
max_elements_per_comm=5e8):
if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
"and partition size")
if dp_process_group is None:
dp_process_group = _initialize_parameter_parallel_groups(partition_size)
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.verbose = verbose
self.dp_process_group = dp_process_group
# TODO: automatically turn off if #params > some_limit
self.all_gather_partitions = all_gather_partitions
self.allgather_size = allgather_size
self.max_elements_per_comm = max_elements_per_comm
print("max_elements_per_comm={}".format(max_elements_per_comm))
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
# Setup bookkeeping data structures depending on partitioning type
# parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids]
self.parallel_sub_partitioned_fp16_groups = []
# same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
self.parallel_comm_sub_partitioned_fp16_groups = []
# 32-bit sub-partitions of the parallel partitioned parameters
# that this process will update
self.local_sub_partitions_of_fp32_groups = []
# param partition info
# parameters in each group that will not be updated by this process directly
self.params_not_local = []
# parameters that will be updated by this process directly
self.params_in_rank_sub_partitions = []
# parameter offsets for parameters in sub-partitions. Parameter
# boundaries may not align with sub-partition boundaries
# so we need to keep track of the offsets
self.params_in_rank_sub_partitions_offsets = []
# number of elements per sub-partition in each group
self.sub_partition_sizes = []
# number of communication intervals for each group
self.num_comm_intervals_per_group = []
local_rank = dist.get_rank(group=self.dp_process_group)
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
# flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
# RS: create aligned sub-partitions
self.fp16_groups_flat.append(
flatten_dense_tensors_sub_partition_aligned(
tensor_list=self.fp16_groups[i],
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elements_per_comm,
pg=self.dp_process_group))
# TODO: I don't think this does anything?
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# divide the flat weights into near-equal partitions, one per data parallel rank
# each process will compute on a different part of the partition
# RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=self.fp16_groups_flat[i],
max_elements_per_comm=self.max_elements_per_comm,
world_size=dist.get_world_size(
group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
self.parallel_comm_sub_partitioned_fp16_groups.append(
comm_partitions) # comm -> rank
self.parallel_sub_partitioned_fp16_groups.append(
dp_sub_partitions) # rank -> comm
self.sub_partition_sizes.append(sub_partition_size)
self.num_comm_intervals_per_group.append(num_comm_intervals)
# data_parallel_partitions = self.get_data_parallel_partitions(self.fp16_groups_flat[i])
# self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
# RS: store/detach/cast our local sub-partitions
local_sub_partitions = []
for sub_partition in self.parallel_sub_partitioned_fp16_groups[i][
local_rank]:
fp32_sub_partition = sub_partition.clone().float().detach()
fp32_sub_partition.requires_grad = True
local_sub_partitions.append(fp32_sub_partition)
self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions)
# modify optimizer to have flat master weights
# self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]
# RS: divide up the sub-partitions and keep track of offsets for each param
# partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
params_not_local = self.get_all_sub_partition_info(
tensor_list=self.fp16_groups[i],
all_element_intervals=element_intervals,
local_rank=local_rank,
world_size=dist.get_world_size(group=self.dp_process_group)
)
self.params_in_rank_sub_partitions.append(params_in_rank_sub_partition)
self.params_not_local.append(params_not_local)
self.params_in_rank_sub_partitions_offsets.append(
params_in_rank_sub_partitions_offsets)
# we may have a way of fusing dynamic scale; not supported for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
self.mpu = mpu
self.clip_grad = clip_grad
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
zero_reduce_scatter=True)
@staticmethod
def get_data_parallel_sub_partitions(tensor,
max_elements_per_comm,
world_size,
dp_process_group=None):
total_num_elements = tensor.numel()
# if total elements is less than our max, revert to splitting into dp partitions
max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
sub_partition_size = int(max_elements_per_comm // world_size)
# Ensure partition alignment was done correctly
num_sub_partitions = int(total_num_elements // sub_partition_size)
assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements, sub_partition_size)
# Ensure comm interval alignment was done correctly.
num_comm_intervals = int(num_sub_partitions // world_size)
assert num_sub_partitions % world_size == 0, "{} % {} != 0".format(num_sub_partitions, world_size)
if not dist.is_initialized() or dist.get_rank(group=dp_process_group) == 0:
print("**** partition info:")
print("\t total_num_elements=", total_num_elements)
print("\t world_size=", world_size)
print("\t max_elements_per_comm=", max_elements_per_comm)
print("\t sub_partition_size=", sub_partition_size)
print("\t num_sub_partitions=", num_sub_partitions)
print("\t num_comm_intervals=", num_comm_intervals)
print("****")
# [comm_id] -> [rank]
comm_partitions = []
for _ in range(num_comm_intervals):
comm_partitions.append([])
start = 0
comm_id = 0
element_intervals = defaultdict(
list) # [rank] -> [(start,end), (start,end), ...]
for idx in range(num_sub_partitions):
rank_id = idx % world_size
sub_partition = tensor.narrow(0, start, sub_partition_size)
element_intervals[rank_id].append((start, start + sub_partition_size))
comm_partitions[comm_id].append(sub_partition)
start = start + sub_partition_size
if rank_id == (world_size - 1):
comm_id += 1
# [rank] -> [comm_id]
sub_partitions = []
for _ in range(world_size):
sub_partitions.append([])
for comm_id, partitions in enumerate(comm_partitions):
for rank_id, partition in enumerate(partitions):
sub_partitions[rank_id].append(partition)
return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
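For intuition, a tiny dry run of this routine (runnable without a process group since `dist.is_initialized()` short-circuits the rank check):
```python
import torch

flat = torch.arange(8.)   # stand-in for a flattened, aligned param group
comm_parts, rank_parts, intervals, sub_size, n_comm = \
    FP16_DeepSpeedZeroOptimizer_Stage1.get_data_parallel_sub_partitions(
        tensor=flat, max_elements_per_comm=4, world_size=2)
# sub_size == 2, n_comm == 2: rank 0 owns slices [0:2] and [4:6], rank 1 owns
# [2:4] and [6:8]; comm_parts and rank_parts expose the same narrowed views
# indexed as [comm][rank] and [rank][comm] respectively.
```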
@staticmethod
def get_all_sub_partition_info(tensor_list,
all_element_intervals,
local_rank,
world_size):
params_not_local = []
# [rank] -> [comm-id] -> [param/offset]
params_in_rank_sub_partition = []
params_in_rank_sub_partitions_offsets = []
for rank in range(world_size):
params_in_local_sub_partition = []
local_sub_partition_offsets = []
comm_tensor_list = []
comm_offset_list = []
current_index = 0
prev_comm_idx = 0
for iii, tensor in enumerate(tensor_list):
tensor_size = tensor.numel()
#if local_rank == 0:
# #print("rank={}, current_index={}, tensor_size={}, tensor-idx={}".format(rank,
# current_index, tensor_size, iii))
results_list = _range_check(current_index,
all_element_intervals[rank],
tensor_size)
for contained, offset, comm_idx in results_list:
#if local_rank == 0:
# print("rank={}, contained={}, offset={}, comm_idx={}".format(rank, contained,
# offset, comm_idx))
if contained:
if prev_comm_idx != comm_idx:
params_in_local_sub_partition.append(comm_tensor_list)
comm_tensor_list = []
local_sub_partition_offsets.append(comm_offset_list)
comm_offset_list = []
comm_tensor_list.append(tensor)
comm_offset_list.append(offset)
prev_comm_idx = comm_idx
elif rank == local_rank:
params_not_local.append(tensor)
current_index = current_index + tensor_size
#assert len(comm_tensor_list) > 0
#assert len(comm_offset_list) > 0
params_in_local_sub_partition.append(comm_tensor_list)
local_sub_partition_offsets.append(comm_offset_list)
params_in_rank_sub_partition.append(params_in_local_sub_partition)
params_in_rank_sub_partitions_offsets.append(local_sub_partition_offsets)
return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
@staticmethod
def get_flat_sub_partitions(comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
num_comm_intervals=None,
default_device=None,
return_partition_params=False):
partition_params = []
final_param_offsets = []
flat_sub_partitions = []
for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
flat_tensor_list = []
current_size = 0
my_offsets = []
my_params = []
if dtype is None:
dtype = tensor_list[0].dtype
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros(tensor.size(),
dtype=tensor.dtype,
device=tensor.device)
param = tensor
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and param_offsets[i] > 0:
tensor_offset = param_offsets[i]
num_elements = num_elements - tensor_offset
# We don't need all elements of the tensor if this tensor is
# larger than we have space for in our curr sub-partition
if num_elements > (sub_partition_size - current_size):
num_elements = sub_partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)).to(dtype))
else:
flat_tensor_list.append(tensor.to(dtype))
my_params.append(param)
#remember offset into partition and #elems for this tensor
my_offsets.append((current_size, num_elements))
current_size = current_size + num_elements
#this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < sub_partition_size:
my_offsets.append((None, None))
my_params.append(None)
if len(tensor_list) == 0:
assert default_device is not None
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=default_device))
else:
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=tensor_list[0].device))
partition_params.append(my_params) #flat_tensor_list)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
#print("padding w. sub partitions to ensure uniform communication")
device = flat_sub_partitions[0].device
for _ in range(num_comm_intervals - len(flat_sub_partitions)):
flat_sub_partitions.append(
torch.zeros(int(sub_partition_size),
dtype=dtype,
device=device))
partition_params.append([None])
final_param_offsets.append([(None, None)])
if return_partition_params:
assert len(flat_sub_partitions) == len(partition_params)
assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params), len(final_param_offsets))
return flat_sub_partitions, partition_params, final_param_offsets
return flat_sub_partitions
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def free_grad_in_param_list(self, param_list):
for p in param_list:
if isinstance(p, list):
for _p in p:
_p.grad = None
else:
p.grad = None
def reduce_scatter_gradients(self,
postscale_gradients,
gradient_predivide_factor,
gradient_average):
world_size = dist.get_world_size(group=self.dp_process_group)
local_rank = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
partition_param_map = {}
param_partition_map = {}
my_params = set()
# [rank] -> [comm] -> partition
num_comm_intervals = self.num_comm_intervals_per_group[i]
all_sub_partitions = []
for rank in range(world_size):
# gsp is list of partitions indexed by comm_idx
#FIXME: currently hardcoding fp16, should infer dtype
grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
sub_partition_size=self.sub_partition_sizes[i],
dtype=torch.half, #self.params_in_rank_sub_partitions[i][rank][0][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device='cuda', #self.params_in_rank_sub_partitions[i][rank][0][0].device,
return_partition_params=True)
all_sub_partitions.append(grad_sub_partitions)
# create map from partition -> params in that partition
for comm_idx, part in enumerate(grad_sub_partitions):
partition_param_map[part] = (partition_params[comm_idx],
param_offsets[comm_idx])
for comm_idx, params in enumerate(partition_params):
for pidx, p in enumerate(params):
# store the parameters we care about locally
if rank == local_rank:
my_params.add(p)
# map from param -> partitions
if p in param_partition_map:
param_partition_map[p].append(grad_sub_partitions[comm_idx])
else:
param_partition_map[p] = [grad_sub_partitions[comm_idx]]
assert len(grad_sub_partitions) == num_comm_intervals
if not postscale_gradients:
raise NotImplementedError("pre-scale_gradients is not implemented")
all_comm_partitions = []
for comm_idx in range(num_comm_intervals):
single_comm_all_partitions = []
for rank in range(world_size):
single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
input_list=single_comm_all_partitions,
group=self.dp_process_group)
if gradient_average:
for partition in single_comm_all_partitions:
partition.mul_(gradient_predivide_factor / world_size)
all_comm_partitions.append(single_comm_all_partitions)
for p in my_params:
partitions = param_partition_map[p]
parts = []
for part in partitions:
params, offsets = partition_param_map[part]
found = False
for p_idx, _p in enumerate(params):
if p.__hash__() == _p.__hash__():
found = True
if offsets[p_idx][0] is not None:
my_part = part.narrow(0,
offsets[p_idx][0],
offsets[p_idx][1])
parts.append(my_part)
assert found
if p is not None:
updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
p.grad.copy_(updated_grad[0])
def step(self, closure=None):
# First compute norm for all groups so we know if there is overflow
self.overflow = self.overflow_checker.check()
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self.zero_grad()
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.loss_scale))
return self.overflow
norm_groups = []
local_sub_partitions_grad_groups = []
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
#TODO RS: update get grad norm to support sub partitions
norm_groups.append(get_grad_norm(group, mpu=self.mpu))
#RS: update free grads w.r.t. sub partitions
#free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_local[i])
#create flat gradients for parameters updated by this process
#tensor_list, first_offset, partition_size, dtype
#single_grad_partition = self.get_flat_partition(
# tensor_list=self.params_in_partition[i],
# first_offset=self.first_offset[i],
# partition_size=self.partition_size[i],
# dtype=self.single_partition_of_fp32_groups[i].dtype
#)
#TODO RS: can we safely use dtype of the first sub-partition? i think so
local_grad_sub_partitions = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
[partition_id],
sub_partition_size=self.sub_partition_sizes[i],
dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device=self.local_sub_partitions_of_fp32_groups[i][0].device)
#RS: update all our local params with sub-partition grads
#print("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions)))
for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
sub_partition_param.grad = local_grad_sub_partitions[idx]
#self.single_partition_of_fp32_groups[i].grad = single_grad_partition
#RS: update free grads for sub-partitions
#release all the gradient since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(
self.params_in_rank_sub_partitions[i][partition_id])
local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
#RS: update unscale/clip with sub partitions
self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups)
self.optimizer.step()
#RS: clear our sub partition grads
#get rid of the fp32 gradients. Not needed anymore
for group in self.local_sub_partitions_of_fp32_groups:
for idx, sub_partition_param in enumerate(group):
sub_partition_param.grad = None
#group.grad = None
#NOTE RS: removed norm_groups outer loop from original code, i don't think it's needed
#RS: copy all sub-partition fp32 data to fp16 sub partitions
# copy fp32 param data to fp16 partitions w.r.t. our local rank
for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
local_sub_partition_param_fp16.data.copy_(
local_sub_partition_param_fp32.data)
#RS: all_gather/broadcast sub-partitions in separate comm calls
#gather the updated weights from everyone
for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
dist.all_gather(sub_partitions,
sub_partitions[partition_id],
group=self.dp_process_group)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
return self.overflow
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
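A numeric sketch of the combined scale: with `loss_scale=1024`, `clip_grad=1.0`, and a scaled total norm of 4096 (true norm 4.0), one multiply both unscales and clips:
```python
loss_scale, clip_grad, total_norm = 1024.0, 1.0, 4096.0  # total_norm is norm*scale
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad    # ~4.0 > 1, so clip
combined_scale = clip * loss_scale if clip > 1 else loss_scale
factor = 1.0 / combined_scale   # ~1/4096: unscale by 1/1024 and clip by ~1/4
```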
def backward(self, loss, retain_graph=False):
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict[
'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
return state_dict
def load_state_dict(self, state_dict, load_optimizer_states=True):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)


@ -0,0 +1,24 @@
import torch
import torch.distributed as dist
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
data_parallel_size = int(dist.get_world_size())
if parameter_parallel_size is None:
parameter_parallel_size = int(data_parallel_size)
print(data_parallel_size, parameter_parallel_size)
assert data_parallel_size % parameter_parallel_size == 0, \
'world size should be divisible by parameter parallel size'
rank = dist.get_rank()
my_group = None
for i in range(dist.get_world_size() // parameter_parallel_size):
ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
my_group = group
return my_group
def pprint(msg):
if not dist.is_initialized() or dist.get_rank() == 0:
print(msg)


@ -29,6 +29,14 @@ collections:
tutorials:
output: true
permalink: /:collection/:path/
order:
- getting-started.md
- azure.md
- cifar-10.md
- bert-pretraining.md
- megatron.md
- 1Cycle.md
- lrrt.md
defaults:
- scope:


@ -18,7 +18,7 @@ lnav:
children:
- title: "Installation"
url: /getting-started/#installation
- title: "Writing Models"
- title: "Writing models"
url: /getting-started/#writing-deepspeed-models
- title: "Training"
url: /getting-started/#training
@ -37,19 +37,25 @@ lnav:
url: /docs/config-json/#communication-options
- title: "FP16"
url: /docs/config-json/#fp16-training-options
- title: "ZeRO optimizations"
url: /docs/config-json/#zero-optimizations-for-fp16-training
- title: "Logging"
url: /docs/config-json/#logging
- title: "Activation checkpointing"
url: /docs/config-json/#activation-checkpointing
- title: "Tutorials"
url: /tutorials/
children:
- title: "Getting Started on Azure"
- title: "Getting started"
url: /getting-started/
- title: "Getting started on Azure"
url: /tutorials/azure/
- title: "CIFAR-10"
url: /tutorials/cifar-10/
- title: "Megatron-LM GPT2"
url: /tutorials/megatron/
- title: "BERT Pre-training"
url: /tutorials/bert-pretraining/
- title: "Megatron-LM GPT2"
url: /tutorials/megatron/
- title: "1-Cycle Schedule"
url: /tutorials/1Cycle/
- title: "Learning Rate Range Test"


@ -12,12 +12,6 @@ layout: archive
{% endif %}
<h2>Features Coming Soon</h2>
{% assign soon = posts | where: "sneak_preview", "true" %}
{% for post in soon %}
{% include archive-single.html %}
{% endfor %}
<h2>{{ site.data.ui-text[site.locale].recent_posts | default: "Recent Posts" }}</h2>
{% assign news = posts | where: "sneak_preview", "false" %}
{% for post in news %}

docs/_pages/config-json.md Normal file → Executable file

@ -102,15 +102,8 @@ Example of ***scheduler***
| ------------------------------------------------------------ | ------- |
| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` |
### FP16 training options
***zero\_optimization***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` |
***fp16***: [dictionary]
| Description | Default |
@ -172,6 +165,66 @@ Example of ***scheduler***
| ----------------------------------- | ------- |
| Enable gradient clipping with value | `0` |
### ZeRO Optimizations for FP16 Training
Enabling and configuring ZeRO memory optimizations:
```json
"zero_optimization": {
"stage": [0|1|2],
"allgather_partitions": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"reduce_bucket_size": 500000000,
"contiguous_gradients" : [true|false]
}
```
***zero\_optimization***: [dictionary]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` |
***stage***: [integer]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Chooses different stages of the ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitioning, respectively. | `0` |
***allgather_partitions***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Chooses between the allgather collective and a series of broadcast collectives to gather updated parameters from all GPUs at the end of each step | `true` |
***allgather_bucket_size***: [integer]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `500000000` |
***reduce_scatter***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Uses reduce or reduce scatter instead of allreduce to average gradients | `true` |
***reduce_bucket_size***: [integer]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Number of elements reduced/allreduced at a time. Limits the memory required for the reduction for large model sizes | `500000000` |
***contiguous_gradients***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during the backward pass. Only useful when running very large models. | `false` |
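Putting the parameters together, a concrete ZeRO-2 configuration might look as follows (shown as the Python-dict equivalent of the JSON above; the values are illustrative, not tuned recommendations):
```python
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 200000000,  # smaller buckets save memory at some speed cost
        "reduce_scatter": True,
        "reduce_bucket_size": 200000000,
        "contiguous_gradients": True,        # worthwhile mainly for very large models
    }
}
```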
### Logging
***steps\_per\_print***: [integer]
@ -191,3 +244,52 @@ Example of ***scheduler***
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Print out state information of DeepSpeed object after initialization | `false` |
### Activation Checkpointing
```json
"activation_checkpointing": {
"partition_activations": false,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"synchronize_checkpoint_boundary": false,
"profile": false
}
```
***partition\_activations***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Enables partitioning of activations when used with model parallelism | `false` |
***cpu\_checkpointing***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Offloads partitioned activations to CPU if partition_activations is enabled | `false` |
***contiguous\_memory\_optimization***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Copies partitioned activations so that they are contiguous in memory | `false` |
***number_checkpoints***: [integer]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Total number of activation checkpoints used to allocate memory buffer for contiguous_memory_optimization | `None` |
***synchronize\_checkpoint\_boundary***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` |
***profile***: [boolean]
| Description | Default |
| ------------------------------------------------------------ | ------- |
| Logs the forward and backward time for each checkpoint function | `false` |
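For instance, a configuration that partitions activations, offloads them to CPU, and profiles each checkpoint could look like this (Python-dict equivalent of the JSON section above; values are illustrative):
```python
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,       # only takes effect with partition_activations
        "contiguous_memory_optimization": True,
        "number_checkpoints": 24,        # e.g. one checkpoint per transformer layer
        "synchronize_checkpoint_boundary": False,
        "profile": True,
    }
}
```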

docs/_pages/features.md Normal file → Executable file

@ -57,19 +57,33 @@ DeepSpeed is fully compatible with [Megatron](https://github.com/NVIDIA/Megatron
Please see the [Megatron-LM tutorial](/tutorials/megatron/) for details.
## Memory and Bandwidth Optimizations
### The Zero Redundancy Optimizer (ZeRO)
[ZeRO](https://arxiv.org/abs/1910.02054) is at the heart of DeepSpeed and
enables large model training at a scale that is simply not possible with model
parallelism alone. When enabled, ZeRO allows training models with
over 6 billion parameters without any model parallelism, and up to 100 billion
parameter models with model parallelism on current generation hardware.
## The Zero Redundancy Optimizer
The Zero Redundancy Optimizer ([ZeRO](https://arxiv.org/abs/1910.02054)) is at
the heart of DeepSpeed and enables large model training at a scale that is
simply not possible with model parallelism alone. When enabled, ZeRO allows
training models with over 13 billion parameters without any model parallelism,
and up to 200 billion parameter models with model parallelism on current
generation hardware.
For more details see the [ZeRO paper](https://arxiv.org/abs/1910.02054), [GPT
tutorial](/tutorials/megatron/) on integration with
DeepSpeed. Additional tutorials including *BERT Tutorial*: Coming Soon.
DeepSpeed.
### Optimizer State and Gradient Partitioning
Optimizer State and Gradient Partitioning in ZeRO reduces the memory consumption of the
model states (optimizer states, gradients and parameters) by 8x compared to standard
data parallelism by partitioning these states across data parallel processes instead of
replicating them.
### Activation Partitioning
Activation Partitioning is a memory optimization in ZeRO that can reduce the memory
consumed by activations during model parallel training (MP). In MP certain
activations may be required by all MP processes, resulting in a replication of
activations across MP GPUs. Activation Partitioning stores these activations in a
partitioned state once they are used for computation in the forward propagation. These
activations are allgathered right before they are needed again during the backward propagation.
By storing activations in a partitioned state, ZeRO in DeepSpeed can reduce the activation
memory footprint proportional to the MP degree.
### Constant Buffer Optimization (CBO)
CBO enables high network and memory throughput while restricting memory usage to a
@ -80,6 +94,17 @@ unnecessary memory overhead. CBO in DeepSpeed fuses smaller operands into approx
pre-defined sized buffer large enough to achieve great performance without the
unnecessary memory overhead.
### Contiguous Memory Optimization (CMO)
CMO reduces memory fragmentation during training, preventing out-of-memory errors
due to lack of contiguous memory. Memory fragmentation is a result of interleaving between
short lived and long lived memory objects. During the forward propagation, activation
checkpoints are long lived but the activations that are recomputed are short lived. Similarly,
during the backward computation, the activation gradients are short lived while the parameter
gradients are long lived. CMO transfers activation checkpoints and parameter gradients
to contiguous buffers, preventing memory fragmentation.
## Additional Memory and Bandwidth Optimizations
### Smart Gradient Accumulation
Gradient accumulation allows running larger batch size with limited memory by breaking an
effective batch into several sequential micro-batches, and averaging the parameter
@ -90,6 +115,11 @@ averaged gradients for the effective batch across all GPUs. This strategy signif
reduces the communication involved over the approach of averaging globally for each
micro-batch, specially when the number of micro-batches per effective batch is large.
### Communication Overlapping
During back propagation, DeepSpeed can overlap the communication required for averaging
parameter gradients that have already been computed with the ongoing gradient computation.
This computation-communication overlap allows DeepSpeed to achieve higher throughput even
at modest batch sizes.
## Training Features
@ -100,12 +130,23 @@ The DeepSpeed core API consists of just a handful of methods:
* argument parsing: `add_config_arguments`
* checkpointing : `load_checkpoint` and `store_checkpoint`
DeepSpeed supports all the features described in this document, via the use of these API,
DeepSpeed supports most of the features described in this document via the use of these APIs,
along with a `deepspeed_config` JSON file for enabling and disabling the features.
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
### Activation Checkpointing API
DeepSpeed's Activation Checkpointing API supports activation checkpoint partitioning,
CPU checkpointing, and contiguous memory optimizations, while also allowing layerwise
profiling. Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
### Gradient Clipping
```json
{
"gradient_clipping": 1.0
}
```
DeepSpeed handles gradient clipping under the hood based on the max gradient norm
specified by the user.
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
@ -136,8 +177,8 @@ DeepSpeed makes it easy to train with large batch sizes by enabling the LAMB Opt
For more details on LAMB, see the [LAMB paper](https://arxiv.org/pdf/1904.00962.pdf).
### Memory-Efficient Training with ZeRO Optimizer
DeepSpeed can train models up with up to 6 billion parameters without parallelism, and
models with up to 100 billion parameters with 16-way model parallelism. This leap in
DeepSpeed can train models with up to 13 billion parameters without parallelism, and
models with up to 200 billion parameters with 16-way model parallelism. This leap in
model size is possible through the memory efficiency achieved via the ZeRO Optimizer. For
more details see the [ZeRO paper](https://arxiv.org/abs/1910.02054).
@ -174,6 +215,10 @@ file.
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
```json
{
"wall_clock_breakdown": true
"wall_clock_breakdown": true,
"activation_checkpointing": {
"profile": true
}
}
```


@ -1,10 +0,0 @@
---
title: "ZeRO stage 2"
sneak_preview: true
excerpt: "Reduce memory footprint to enable training 10B models without model parallelism!"
---
* Reduce memory footprint of gradients
* Train larger models: e.g., 10B parameters on 32GPUs without model parallelism
* Train larger batch sizes
## Further updates coming soon!


@ -0,0 +1,21 @@
---
layout: single
title: "DeepSpeed optimizes transformer kernels to achieve world's fastest BERT training record: 44 minutes on 1024 NVIDIA V100 GPUs"
excerpt: ""
categories: news
new_post: true
date: 2020-05-19 00:00:00
---
We introduce new technology to accelerate single GPU performance via
kernel optimizations. These optimizations not only create a strong
foundation for scaling out large models, but also improve the single GPU
performance of highly tuned and moderately sized models like BERT by more
than 30%, reaching a staggering performance of 66 teraflops per V100 GPU,
which is 52% of the hardware peak. **Using these optimizations as the building
block, DeepSpeed achieves the fastest BERT training record: 44 minutes on
1,024 NVIDIA V100 GPUs**, compared with the best published result
of 67 minutes on the same number and generation of GPUs.
**Code and tutorials are coming soon!**


@ -0,0 +1,22 @@
---
layout: single
title: "ZeRO-2 empowers training models as large as 170 billion parameters up to 10x faster compared to state-of-the-art"
excerpt: ""
categories: news
new_post: true
date: 2020-05-19 01:00:00
---
ZeRO-2 expands the scope of memory optimizations in the original ZeRO by
tackling the full spectrum of memory consumption during training. More
specifically, ZeRO-2 introduces new technology to reduce the memory footprint
of gradients, activation memory, and fragmented memory, in addition to
optimizer state memory optimization in the original ZeRO. Altogether, the
memory savings empower DeepSpeed to improve the scale and speed of deep
learning training by an order of magnitude. More concretely, ZeRO-2 allows
training models as large as 170 billion parameters up to 10x faster compared
to state of the art.
For more information on using ZeRO-2, see the [Megatron tutorial](/tutorials/megatron/).
For a technical deep dive, see our [technical report](https://arxiv.org/abs/1910.02054).


@ -2,6 +2,7 @@
title: "Getting Started"
permalink: /getting-started/
excerpt: "First steps with DeepSpeed"
date: 2020-05-15
---
## Installation


@ -320,6 +320,43 @@ and return the states for the client model.
```
### DeepSpeed Activation Checkpoints (Optional)
DeepSpeed can reduce the activation memory during model parallel training by partitioning activation checkpoints across model parallel GPUs, or offloading them to CPU. These optimizations are optional, and can be skipped unless activation memory becomes a bottleneck. To enable activation partitioning, we use the `deepspeed.checkpointing` API to replace Megatron's activation checkpointing and random state tracker APIs. The replacement should happen before the first invocation of these APIs.
a) Replace in `pretrain_gpt.py` :
```python
# Optional DeepSpeed Activation Checkpointing Features
#
if args.deepspeed and args.deepspeed_activation_checkpointing:
set_deepspeed_activation_checkpointing(args)
def set_deepspeed_activation_checkpointing(args):
deepspeed.checkpointing.configure(mpu,
deepspeed_config=args.deepspeed_config,
partition_activation=True)
mpu.checkpoint = deepspeed.checkpointing.checkpoint
mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
```
b) Replace in `mpu/transformer.py`:
```python
if deepspeed.checkpointing.is_configured():
global get_cuda_rng_tracker, checkpoint
get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
checkpoint = deepspeed.checkpointing.checkpoint
```
With these replacements, various DeepSpeed activation checkpointing optimizations such as activation partitioning, contiguous checkpointing, and CPU checkpointing, can be specified with either `deepspeed.checkpointing.configure` or in the `deepspeed_config` file.
### Train scripts
Assume the webtext data was prepared in the previous step. To start training the
Megatron-LM GPT2 model with DeepSpeed applied, execute the following command to
@ -328,13 +365,18 @@ start training.
- Single GPU run
- run `bash scripts/ds_pretrain_gpt2.sh`
- Multiple GPUs/Nodes run
- run `bash scripts/ds_pretrain_gpt2_model_parallel.sh`
- run `bash scripts/ds_zero2_pretrain_gpt2_model_parallel.sh`
## DeepSpeed Evaluation using GPT-2
## Performance Improvements
DeepSpeed enables training very large models effectively via the advanced [ZeRO
optimizer](https://arxiv.org/abs/1910.02054v2). ZeRO significantly reduces the memory
optimizer](https://arxiv.org/abs/1910.02054v2). In February, we released a subset
of optimizations from ZeRO in DeepSpeed that performs optimizer state partitioning.
We refer to this as ZeRO-1. In May 2020, we extended ZeRO-1 in DeepSpeed to include
additional optimizations from ZeRO, including gradient and activation partitioning,
as well as contiguous memory optimizations. We refer to this release as ZeRO-2.
ZeRO-2 significantly reduces the memory
footprint for training large models which means large models can be trained with i) less
model parallelism and ii) larger batch sizes. A lower model parallelism degree improves
training efficiency by increasing the granularity of the computation such as the matrix
@ -342,80 +384,25 @@ multiplication where performance is directly related to the size of the matrices
Furthermore, less model parallelism also results in less communication between model
parallel GPUs, which further boosts performance. Larger batch size has a similar effect
of increasing the computational granularity as well as reducing communication, also
resulting in better performance. Therefore, DeepSpeed combines ZeRO-powered data parallelism with
Megatron-LM tensor-slicing model parallelism, which is
significantly faster than using Megatron-LM alone.
resulting in better performance. Therefore, with DeepSpeed and ZeRO-2 integration into Megatron,
we elevate the model scale and speed to an entirely new level compared to Megatron alone.
The observed performance improvements depend on several factors such as the memory per
GPU, the local GPU interconnect (e.g., PCI-E vs. NVLink vs. NVSwitch), the model size,
inter node network interconnect, etc. Below, we show some of the performance improvements
from using DeepSpeed over Megatron on a 16 GPU Low Bandwidth (40 Gbps) cluster and a 400 GPU DGX-2 High Bandwidth (800 Gbps) cluster.
For details please see the [ZeRO Paper](https://arxiv.org/abs/1910.02054v2). We also
present performance improvement on a 64 GPU cluster along with detailed configuration
analysis to show where the improvements come from.
![DeepSpeed-vs-Megatron](/assets/images/DeepSpeed-vs-Megatron.png)
![DeepSpeed-vs-Megatron](../assets/images/zero-full.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of Nvidia Megatron-LM) over using Megatron-LM alone.</em>
<em>Figure 2: ZeRO-2 scales to 170 billion parameters, has up to 10x higher throughput, obtains super linear speedup, and improves usability by avoiding the need for code refactoring for models up to 13 billion parameters.</em>
</p>
### On Low Bandwidth GPU Cluster
The figure above shows that training a 1.5B-parameter model with DeepSpeed is
nearly 4x faster than without DeepSpeed on a cluster with 4 nodes, 4 GPUs per
node, and 16 GPUs total. These GPUs have 16 GB of memory each; PCI-E
interconnects the GPUs within a node, and 40 Gbps InfiniBand connects the nodes.
The performance improvement comes from the lower model parallelism degree and
larger batch size discussed earlier. Training the 1.5B-parameter model with
Megatron-LM alone requires 4-way model parallelism and can only fit an effective
batch size of 32 using all 16 GPUs. On the other hand, DeepSpeed does not
require any model parallelism to train this model and can support an
effective batch size of 128 without running out of memory, resulting in
significantly higher performance.
More concretely, DeepSpeed and ZeRO-2 excel in four aspects (as visualized in Figure 2): they support models an order of magnitude bigger, run up to 10x faster, scale superlinearly, and improve usability to democratize large model training. These four aspects are detailed below.
### On High bandwidth DGX-2 GPU Cluster
Each GPU on the DGX-2 cluster has 32 GB of memory, and the GPUs inside a box are connected via
the high-bandwidth NVSwitch. DGX-2 nodes are connected to each other via an 800 Gbps (8 x 100 Gbps) InfiniBand interconnect. As such, running a 1.5B model on DGX-2 requires less model
parallelism, and the performance improvement from DeepSpeed for this model size is less
significant. However, at larger model sizes, Megatron still requires a significantly larger
model parallelism degree and can only run much smaller batch sizes than DeepSpeed.
Therefore, as model sizes get larger, DeepSpeed, by combining ZeRO with Megatron model parallelism, starts to significantly outperform
using Megatron-LM alone.
**Model size:** State-of-the-art large models such as OpenAI GPT-2, NVIDIA Megatron-LM, Google T5, and Microsoft Turing-NLG have sizes of 1.5B, 8.3B, 11B, and 17B parameters, respectively. ZeRO-2 provides system support to efficiently run models of 170 billion parameters, an order of magnitude bigger than these largest models (Figure 2, top left).
### Performance Improvements with Configuration Details
The figure below compares DeepSpeed with Megatron on a 64-GPU cluster with 4
DGX-2 nodes. To give readers a clear idea of the source of the performance
improvements, we also present configuration tables for both Megatron and
DeepSpeed. They show the smallest model parallelism degree and the largest batch
size that can be used to train these models without running out of memory. As
discussed above, the tables demonstrate that DeepSpeed runs with a smaller model parallelism degree
and achieves better performance.
**Speed:** Improved memory efficiency powers higher throughput and faster training. Figure 2 (bottom left) shows the system throughput of ZeRO-2 and ZeRO-1 (both combining ZeRO-powered data parallelism with NVIDIA Megatron-LM model parallelism), as well as the state-of-the-art model parallelism approach of using Megatron-LM alone (the baseline in Figure 2, bottom left). ZeRO-2 runs 100-billion-parameter models on a 400 NVIDIA V100 GPU cluster with over 38 teraflops per GPU and aggregated performance over 15 petaflops. For models of the same size, ZeRO-2 is 10x faster in training speed compared with using Megatron-LM alone and 5x faster compared with ZeRO-1.
![DeepSpeed Performance SpeedUp](/assets/images/megatron-gpt2-perf-test.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of Nvidia Megatron-LM) over using Megatron-LM alone.</em>
</p>
**Scalability:** We observe superlinear speedup (Figure 2, top right), where the performance more than doubles when the number of GPUs is doubled. ZeRO-2 reduces the memory footprint of the model states as we increase the data parallelism degree, allowing us to fit larger batch sizes per GPU and resulting in better performance.
**Democratizing large model training:** ZeRO-2 empowers model scientists to train models up to 13 billion parameters efficiently without any model parallelism, which typically requires model refactoring (Figure 2, bottom right). 13 billion parameters is larger than most of the largest state-of-the-art models (such as Google T5, with 11 billion parameters). Model scientists can therefore experiment freely with large models without worrying about model parallelism. In comparison, implementations of classic data-parallelism approaches (such as PyTorch Distributed Data Parallel) run out of memory with 1.4-billion-parameter models, while ZeRO-1 supports up to 6 billion parameters.
**a) Megatron-LM GPT2 Baseline**
| | Model Parallelism | Data Parallelism | #gpus | batch size | layers | hidden size | attention heads | samples / sec |
| ---- | ----------------: | ---------------: | ----: | ---------: | -----: | -----------:| --------------: | ------------: |
| 1.5B | 2 | 32 | 64 | 512 | 48 | 1600 | 16 | 128.56 |
| 4B | 4 | 16 | 64 | 128 | 64 | 2304 | 16 | 49.36 |
| 8B | 4 | 16 | 64 | 128 | 72 | 3072 | 24 | 24.57 |
| 20B | 16 | 4 | 64 | 16 | 111 | 3808 | 32 | 3.42 |
**b) Megatron-LM GPT2 with DeepSpeed**
| | Model Parallelism | Data Parallelism | #gpus | batch size | layers | hidden size | attention heads | samples / sec |
| ---- | ----------------: | ---------------: | ----: | ---------: | -----: | -----------:| --------------: | ------------: |
| 1.5B | 1 | 64 | 64 | 2048 | 48 | 1600 | 16 | 151.35 |
| 4B | 1 | 64 | 64 | 512 | 64 | 2304 | 16 | 75.13 |
| 8B | 2 | 32 | 64 | 512 | 72 | 3072 | 24 | 43.52 |
| 20B | 4 | 16 | 64 | 128 | 111 | 3808 | 32 | 12.65 |
Furthermore, in the absence of model parallelism, these models can be trained on low-bandwidth clusters while still achieving significantly better throughput than with model parallelism. For example, the GPT-2 model can be trained nearly 4x faster with ZeRO-powered data parallelism than with model parallelism on a four-node cluster connected with a 40 Gbps InfiniBand interconnect, where each node has four NVIDIA 16 GB V100 GPUs connected with PCI-E. Therefore, with this performance improvement, large model training is no longer limited to GPU clusters with ultra-fast interconnects but is also accessible on modest clusters with limited bandwidth.
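To sketch what ZeRO-powered data parallelism without model parallelism looks like in user code, the snippet below mirrors the pattern of the test scripts added in this commit (the config is written to a JSON file and passed via `args.deepspeed_config`); the model and config values are illustrative only.

```python
import argparse, json, os, tempfile
import torch
import deepspeed

# Illustrative ZeRO stage 2 config, in the style of the
# ds_config_func_bs8_zero2.json files added in this commit.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2}
}
config_path = os.path.join(tempfile.mkdtemp(), 'ds_config.json')
with open(config_path, 'w') as f:
    json.dump(ds_config, f)

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()
args.deepspeed_config = config_path

net = torch.nn.Linear(1024, 1024)   # stand-in model; no tensor slicing needed
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=net, model_parameters=net.parameters())
```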

docs/assets/images/DeepSpeed-vs-Megatron.png: binary image, 96 KiB (Executable file → Normal file)

docs/assets/images/deepspeed-speedup.png: binary image, 78 KiB (new file, not shown)

docs/assets/images/deepspeed-throughput-seq512.png: binary image, 13 KiB (Executable file → Normal file)

docs/assets/images/zero-full.png: binary image, 121 KiB (new file, not shown)
docs/code-docs/requirements.local.txt Normal file → Executable file
View file

View file

@ -1,3 +1,2 @@
tqdm
psutil
tensorboardX==1.8

View file

@ -0,0 +1,38 @@
Activation Checkpointing
========================
The activation checkpointing APIs in DeepSpeed can be used to enable a range
of memory optimizations relating to activation checkpointing. These include
activation partitioning across GPUs when using model parallelism, CPU
checkpointing, contiguous memory optimizations, etc.
Please see the `DeepSpeed JSON config <https://www.deepspeed.ai/docs/config-json/>`_
for the full set.
Here we present the activation checkpointing API. Please see the
`Megatron-LM tutorial <https://www.deepspeed.ai/tutorials/megatron/>`_ on enabling
DeepSpeed for example usage.
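As a brief sketch (the ``configure`` keyword names below are assumptions;
see the autodoc entries that follow for the authoritative signatures), a
typical Megatron-style integration configures checkpointing once and then
rebinds the checkpoint entry point:

.. code-block:: python

    import deepspeed

    # 'mpu' (the model-parallel utility object) and 'args' are assumed to
    # come from the surrounding application.
    deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config)

    # Rebind the checkpoint entry point, as in the Megatron-LM tutorial.
    checkpoint = deepspeed.checkpointing.checkpoint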
Configuring Activation Checkpointing
------------------------------------
.. autofunction:: deepspeed.checkpointing.configure
.. autofunction:: deepspeed.checkpointing.is_configured
Using Activation Checkpointing
------------------------------
.. autofunction:: deepspeed.checkpointing.checkpoint
.. autofunction:: deepspeed.checkpointing.reset
Configuring and Checkpointing Random Seeds
------------------------------------------
.. autofunction:: deepspeed.checkpointing.get_cuda_rng_tracker
.. autofunction:: deepspeed.checkpointing.model_parallel_cuda_manual_seed
.. autoclass:: deepspeed.checkpointing.CudaRNGStatesTracker
.. autoclass:: deepspeed.checkpointing.CheckpointFunction

View file

@ -0,0 +1,26 @@
DeepSpeed Activation Checkpointing
==================================
The activation checkpointing APIs in DeepSpeed can be used to enable a range of memory optimizations relating
to activation checkpointing. These include activation partitioning across
GPUs when using model parallelism, CPU Checkpointing, contiguous memory optimizations, etc.
Please see the `DeepSpeed JSON config <https://www.deepspeed.ai/docs/config-json/>`_ for the full set.
Here we present the activation checkpointing APIs.
Please see the Megatron-LM tutorial on enabling DeepSpeed for usage details.
.. autofunction:: deepspeed.checkpointing.configure
.. autofunction:: deepspeed.checkpointing.is_configured
.. autofunction:: deepspeed.checkpointing.checkpoint
.. autofunction:: deepspeed.checkpointing.reset
.. autofunction:: deepspeed.checkpointing.get_cuda_rng_tracker
.. autofunction:: deepspeed.checkpointing.model_parallel_cuda_manual_seed
.. autoclass:: deepspeed.checkpointing.CudaRNGStatesTracker
.. autoclass:: deepspeed.checkpointing.CheckpointFunction

View file

@ -71,25 +71,9 @@ html_context = {
from unittest.mock import MagicMock
sys.path.insert(0, os.path.abspath('../../../'))
# Prepend module names to class descriptions?
add_module_names = True
class Mock(MagicMock):
@classmethod
def __getattr__(cls, name):
return MagicMock()
autoclass_content = 'both'
MOCK_MODULES = [
'torch',
'torch.utils',
'torch.utils.data',
'torch.utils.data.distributed',
'torch._utils',
'torch.cuda',
'torch.nn.modules',
'torch.nn',
'torch.distributed',
'torch.distributed.distributed_c10d',
'torch.optim',
'torch._six'
]
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
autodoc_mock_imports = ["torch", "apex", "mpi4py", "tensorboardX"]

View file

@ -1,15 +1,35 @@
DeepSpeed
=========
Model Setup
-----------
.. toctree::
:maxdepth: 2
:caption: Contents:
initialize
checkpointing
Training API
------------
.. toctree::
:maxdepth: 2
training
Checkpointing API
-----------------
.. toctree::
:maxdepth: 2
model-checkpointing
activation-checkpointing
Indices and tables
==================
------------------
* :ref:`genindex`
* :ref:`modindex`

View file

@ -1,6 +1,30 @@
Initializing DeepSpeed
======================
Training Setup
==============
.. _deepspeed-args:
Argument Parsing
----------------
DeepSpeed uses the `argparse <https://docs.python.org/3/library/argparse.html>`_ library to
supply command-line configuration to the DeepSpeed runtime. Use ``deepspeed.add_config_arguments()``
to add DeepSpeed's built-in arguments to your application's parser.
.. code-block:: python
parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')
# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()
.. autofunction:: deepspeed.add_config_arguments
.. _deepspeed-init:
Training Initialization
-----------------------
The entrypoint for all training with DeepSpeed is ``deepspeed.initialize()``.
Example usage:
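A minimal sketch (``net`` stands for your ``torch.nn.Module`` and ``cmd_args``
for the parsed arguments from the section above):

.. code-block:: python

    model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args,
                                                         model=net,
                                                         model_parameters=net.parameters())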

View file

@ -0,0 +1,12 @@
Model Checkpointing
===================
DeepSpeed provides routines for checkpointing model state during training.
Loading Training Checkpoints
----------------------------
.. autofunction:: deepspeed.DeepSpeedLight.load_checkpoint
Saving Training Checkpoints
---------------------------
.. autofunction:: deepspeed.DeepSpeedLight.save_checkpoint
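A short sketch of the round trip (``save_dir`` and ``tag`` are illustrative;
the calls mirror the checkpoint tests added in this commit):

.. code-block:: python

    # Save engine state under <save_dir>/<tag>/.
    model_engine.save_checkpoint(save_dir, tag)

    # Later, restore model state; optionally restore optimizer and
    # LR scheduler state as well.
    model_engine.load_checkpoint(save_dir, tag,
                                 load_optimizer_states=True,
                                 load_lr_scheduler_states=True)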

View file

@ -0,0 +1,29 @@
Training API
============
:func:`deepspeed.initialize` returns a *model engine* as its first return value,
of type ``DeepSpeedLight``. This engine is used to advance training:
.. code-block:: python
for step, batch in enumerate(data_loader):
# forward() method
loss = model_engine(batch)
# runs backpropagation
model_engine.backward(loss)
# weight update
model_engine.step()
Forward Propagation
-------------------
.. autofunction:: deepspeed.DeepSpeedLight.forward
Backward Propagation
--------------------
.. autofunction:: deepspeed.DeepSpeedLight.backward
Optimizer Step
--------------
.. autofunction:: deepspeed.DeepSpeedLight.step

docs/index.md Normal file → Executable file
View file

@ -8,11 +8,11 @@ DeepSpeed is a deep learning optimization library that makes distributed trainin
efficient, and effective.
<p align="center"><i><b>10x Larger Models</b></i></p>
<p align="center"><i><b>5x Faster Training</b></i></p>
<p align="center"><i><b>10x Faster Training</b></i></p>
<p align="center"><i><b>Minimal Code Change</b></i></p>
DeepSpeed can train DL models with over a hundred billion parameters on current
generation of GPU clusters, while achieving over 5x in system performance
generation of GPU clusters, while achieving over 10x in system performance
compared to the state-of-art. Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
@ -22,9 +22,9 @@ establishing a new SOTA in the LM category.
{% assign news = site.posts | where: "sneak_preview", "false" %}
{% for post in news limit:5 %}
{% if post.link %}
* [{{ post.title }}]({{ post.link }})
* [{{ post.date | date: "%Y/%m/%d" }}] [{{ post.title }}]({{ post.link }}) {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %}
{% else %}
* [{{ post.title }}]({{ post.url }})
* [{{ post.date | date: "%Y/%m/%d"}}] [{{ post.title }}]({{ post.url }}) {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %}
{% endif %}
{% endfor %}
@ -54,19 +54,20 @@ DeepSpeed achieves high performance and fast convergence through a combination o
efficiency optimizations on compute/communication/memory/IO and effectiveness
optimizations on advanced hyperparameter tuning and optimizers. For example:
* DeepSpeed trains BERT-large to parity in 14 hours using 64 GPUs (4 DGX-2 boxes) and in
3.7 hours using 256 GPUs (16 DGX-2 boxes).
* <span style="color:dodgerblue">DeepSpeed trains BERT-large to parity in 44
mins using 1024 V100 GPUs (64 DGX-2 boxes) and in 2.4 hours using 256 GPUs
(16 DGX-2 boxes).</span>
**BERT-large Training Times**
| Devices | Source | Training Time (hours) |
| ------------- | --------- | ---------------------:|
| 64 TPUs | Google | 96 |
| 64 V100 GPUs | DeepSpeed | **14** |
| 256 V100 GPUs | NVIDIA | 3.9 |
| 256 V100 GPUs | DeepSpeed | **3.7** |
| Devices | Source | Training Time |
| -------------- | --------- | ---------------------:|
| 1024 V100 GPUs | DeepSpeed | **44** min |
| 256 V100 GPUs  | DeepSpeed | **2.4** hr |
| 64 V100 GPUs   | DeepSpeed | **8.68** hr |
| 16 V100 GPUs   | DeepSpeed | **33.22** hr |
*BERT Tutorial*: Coming Soon
*BERT codes and tutorials will be available soon.*
* DeepSpeed trains GPT2 (1.5 billion parameters) 3.75x faster than the state-of-the-art NVIDIA
Megatron on Azure GPUs.
@ -77,37 +78,42 @@ optimizations on advanced hyperparameter tuning and optimizers. For example:
## Memory efficiency
DeepSpeed provides memory-efficient data parallelism and enables training models without
model parallelism. For example, DeepSpeed can train models with up to 6 billion parameters on
model parallelism. For example, DeepSpeed can train models with up to 13 billion parameters on
NVIDIA V100 GPUs with 32GB of device memory. In comparison, existing frameworks (e.g.,
PyTorch's Distributed Data Parallel) run out of memory with 1.5 billion parameter models.
PyTorch's Distributed Data Parallel) run out of memory with 1.4 billion parameter models.
DeepSpeed reduces the training memory footprint through a novel solution called Zero
Redundancy Optimizer (ZeRO). Unlike basic data parallelism where memory states are
replicated across data-parallel processes, ZeRO partitions model states to save
significant memory. The current implementation (stage 1 of ZeRO) reduces memory by up to
4x relative to the state-of-art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054).
replicated across data-parallel processes, ZeRO partitions model states and gradients to save
significant memory. Furthermore, it also reduces activation memory and fragmented memory.
The current implementation (ZeRO-2) reduces memory by up to
8x relative to the state-of-art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054), and
in our blog posts related to
[ZeRO-1](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/). <!-- and [ZeRO-2](linklink). -->
With this impressive memory reduction, early adopters of DeepSpeed have already
produced a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
<a href="https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft">
<span style="color:dodgerblue">Turing-NLG</span></a>,
establishing a new SOTA in the LM category.
## Scalability
DeepSpeed supports efficient data parallelism, model parallelism, and their
combination. ZeRO boosts the scaling capability and efficiency further.
* DeepSpeed provides system support to run models up to 100 billion parameters,
10x larger than the state-of-art (8 billion NVIDIA GPT, 11 billion Google T5).
* DeepSpeed can run large models more efficiently, up to 6x faster for models with
various sizes spanning 1.5B to 100B. More specifically, the data parallelism powered by ZeRO
* <span style="color:dodgerblue">DeepSpeed provides system support to run models up to 170 billion parameters,
10x larger than the state-of-art (8 billion NVIDIA GPT, 11 billion Google T5).</span>
* <span style="color:dodgerblue">DeepSpeed can run large models more efficiently, up to 10x
faster for models with
various sizes spanning 1.5B to 170B.</span> More specifically, the data parallelism powered by ZeRO
is complementary and can be combined with different types of model parallelism. It allows
DeepSpeed to fit models using lower degree of model parallelism and higher batch size, offering
significant performance gains compared to using model parallelism alone.
*Read more*: [technical report](https://arxiv.org/abs/1910.02054),
*Read more*: [ZeRO paper](https://arxiv.org/abs/1910.02054),
and [GPT tutorial](/tutorials/megatron).
![DeepSpeed-vs-Megatron](/assets/images/DeepSpeed-vs-Megatron.png)
![DeepSpeed Speedup](/assets/images/deepspeed-speedup.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of NVIDIA Megatron-LM) over using Megatron-LM alone.</em>
</p>
@ -123,7 +129,7 @@ convergence to desired accuracy.
## Good Usability
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to six billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.3 billion parameters. In addition, DeepSpeed conveniently supports flexible combination of ZeRO-powered data parallelism with custom model parallelisms, such as tensor slicing of NVIDIA's Megatron-LM.
Only a few lines of code changes are needed to enable a PyTorch model to use DeepSpeed and ZeRO. Compared to current model parallelism libraries, DeepSpeed does not require a code redesign or model refactoring. It also does not put limitations on model dimensions (such as number of attention heads, hidden sizes, and others), batch size, or any other training parameters. For models of up to 13 billion parameters, you can use ZeRO-powered data parallelism conveniently without requiring model parallelism, while in contrast, standard data parallelism will run out of memory for models with more than 1.4 billion parameters. In addition, DeepSpeed conveniently supports flexible combination of ZeRO-powered data parallelism with custom model parallelisms, such as tensor slicing of NVIDIA's Megatron-LM.
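As a sketch of those few lines (illustrative, following the training API documented in this commit; `args`, `model`, and `data_loader` are assumed to exist in your script):

```python
import deepspeed

# Wrap the model once; ZeRO is enabled purely through the JSON config.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters())

for step, batch in enumerate(data_loader):
    loss = model_engine(batch)    # forward
    model_engine.backward(loss)   # replaces loss.backward()
    model_engine.step()           # replaces optimizer.step()
```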
## Features
@ -137,12 +143,17 @@ overview](features) for descriptions and usage.
* [Model Parallelism](features.md#model-parallelism)
* Support for Custom Model Parallelism
* Integration with Megatron-LM
* [Memory and Bandwidth Optimizations](features.md#memory-and-bandwidth-optimizations)
* The Zero Redundancy Optimizer (ZeRO)
* Constant Buffer Optimization (CBO)
* [The Zero Redundancy Optimizer (ZeRO)](features.md#the-zero-redundancy-optimizer)
* Optimizer State and Gradient Partitioning
* Activation Partitioning
* Constant Buffer Optimization
* Contiguous Memory Optimization
* [Additional Memory and Bandwidth Optimizations](features.md#additional-memory-and-bandwidth-optimizations)
* Smart Gradient Accumulation
* Communication/Computation Overlap
* [Training Features](features.md#training-features)
* Simplified training API
* Activation Checkpointing API
* Gradient Clipping
* Automatic loss scaling with mixed precision
* [Training Optimizers](features.md#training-optimizers)

View file

@ -56,6 +56,19 @@ class BingBertSquadFuncTestCase(BaseTestCase):
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_gpu4_fp16_zero2(self):
test_config = {
"gpus": 4,
"deepspeed": False,
"json": "deepspeed_bsz24_fp16_zero2_config.json",
"max_steps": 8,
"max_epoch_steps": 4,
"other_args": "--fp16 --print_steps 1"
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_gpu1_fp16(self):
test_config = {
"gpus": 1,
@ -151,6 +164,7 @@ class BingBertSquadFuncTestCase(BaseTestCase):
def suite():
suite = unittest.TestSuite()
suite.addTest(BingBertSquadFuncTestCase('test_gpu4_fp16'))
suite.addTest(BingBertSquadFuncTestCase('test_gpu4_fp16_zero2'))
suite.addTest(BingBertSquadFuncTestCase('test_gpu1_fp16'))
suite.addTest(BingBertSquadFuncTestCase('test_gpu4_fp32'))
suite.addTest(BingBertSquadFuncTestCase('test_gpu1_fp32'))

View file

@ -1,14 +1,6 @@
{
"tensorboard": {
"enabled": false,
"job_name": "MyJob"
},
"zero_optimization": true,
"disable_allgather": false,
"allgather_size": 200000,
"wall_clock_breakdown": false,
"train_batch_size": 24,
"train_micro_batch_size_per_gpu": 3,
"train_micro_batch_size_per_gpu": 6,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
@ -21,5 +13,8 @@
"gradient_clipping": 1.0,
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 1
}
}

View file

@ -0,0 +1,20 @@
{
"train_batch_size": 24,
"train_micro_batch_size_per_gpu": 6,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 3e-5,
"weight_decay": 0.0,
"bias_correction": false
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 2
}
}

View file

@ -1,6 +1,6 @@
{
"train_batch_size": 24,
"train_micro_batch_size_per_gpu": 3,
"train_micro_batch_size_per_gpu": 6,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",

View file

@ -122,7 +122,7 @@ echo "deepspeed: ${enable_deepspeed}"
echo "other_args: ${other_args}"
EFFECTIVE_BATCH_SIZE=${batch_size}
MAX_GPU_BATCH_SIZE=3
MAX_GPU_BATCH_SIZE=6
PER_GPU_BATCH_SIZE=$((EFFECTIVE_BATCH_SIZE/num_gpus))
if [[ $PER_GPU_BATCH_SIZE -lt $MAX_GPU_BATCH_SIZE ]]; then
GRAD_ACCUM_STEPS=1

View file

@ -2,10 +2,11 @@
"train_batch_size": 4,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": true,
"zero_optimization": {
"stage":1
},
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0

View file

@ -0,0 +1,23 @@
{
"train_batch_size": 4,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": {
"stage":2
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0
}
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
}
}

View file

@ -2,12 +2,14 @@
"train_batch_size": 8,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": false,
"zero_optimization": {
"stage":0
},
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015
"lr": 0.00015,
"max_grad_norm": 1.0
}
},

View file

@ -2,10 +2,11 @@
"train_batch_size": 8,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": true,
"zero_optimization":{
"stage":1
},
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0

View file

@ -0,0 +1,28 @@
{
"train_batch_size": 8,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": {
"stage":2
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0
}
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"activation_checkpointing": {
"partition_activations": true,
"contiguous_memory_optimization": true
}
}

View file

@ -2,10 +2,11 @@
"train_batch_size": 4,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": true,
"zero_optimization": {
"stage":2
},
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0

View file

@ -2,11 +2,10 @@
"train_batch_size": 16,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": true,
"zero_optimization": 1,
"disable_allgather": true,
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0

tests/model/Megatron_GPT2/ds_config_perf_bs32.json Normal file → Executable file
View file

@ -2,11 +2,12 @@
"train_batch_size": 32,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": true,
"zero_optimization": {
"stage":1
},
"disable_allgather": true,
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0

View file

@ -2,11 +2,10 @@
"train_batch_size": 8,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"zero_optimization": true,
"zero_optimization": 1,
"disable_allgather": true,
"optimizer": {
"type": "Adam",
"legacy_fusion": false,
"params": {
"lr": 0.00015,
"max_grad_norm": 1.0

View file

@ -85,6 +85,7 @@ gpt_options=" \
--checkpoint-activations \
--checkpoint-num-layers ${ckpt_num_layers} \
--fp16 \
--cache-dir /tmp/cache_dir \
--log-interval 1 \
${other_args} \
${ds_opt} \
@ -92,7 +93,7 @@ gpt_options=" \
"
work_dir="../../../DeepSpeedExamples/Megatron-LM/"
run_cmd="(cd ${work_dir} && deepspeed --num_gpus $gpus pretrain_gpt2.py ${gpt_options})"
run_cmd="(cd ${work_dir} && deepspeed --num_nodes $nodes --num_gpus $gpus pretrain_gpt2.py ${gpt_options})"
echo ${run_cmd}
eval ${run_cmd}

View file

@ -43,7 +43,7 @@ class GPT2CheckpointTestCase(BaseTestCase):
def tearDown(self):
os.chdir(self.save_dir)
def test_mp4_gpu16_node1_with_zero(self):
def test_mp4_gpu16_node1_with_zero1(self):
test_config = {
"mp": 2,
"gpus": 4,
@ -55,12 +55,34 @@ class GPT2CheckpointTestCase(BaseTestCase):
"seq_length": 256,
"heads": 16,
"deepspeed": True,
"tag": "ds_zero",
"tag": "ds_zero1",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp4_gpu16_w_zero",
"checkpoint_name": "ckpt_mp4_gpu16_w_zero1",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8.json",
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp4_gpu16_node1_with_zero2(self):
test_config = {
"mp": 2,
"gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": 2,
"hidden_size": 256,
"seq_length": 256,
"heads": 16,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp4_gpu16_w_zero2",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
@ -184,7 +206,8 @@ class GPT2CheckpointTestCase(BaseTestCase):
def checkpoint_suite():
suite = unittest.TestSuite()
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero'))
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero2'))
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_without_zero'))
return suite

View file

@ -43,7 +43,7 @@ class GPT2FuncTestCase(BaseTestCase):
def tearDown(self):
os.chdir(self.save_dir)
def test_mp1_gpu1_node1(self):
def test_mp1_gpu1_node1_zero1(self):
test_config = {
"mp": 1,
"gpus": 1,
@ -55,13 +55,13 @@ class GPT2FuncTestCase(BaseTestCase):
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs4.json",
"json": "ds_config_func_bs4_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp1_gpu2_node1(self):
def test_mp1_gpu2_node1_zero1(self):
test_config = {
"mp": 1,
"gpus": 2,
@ -73,13 +73,13 @@ class GPT2FuncTestCase(BaseTestCase):
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs8.json",
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu4_node1(self):
def test_mp2_gpu4_node1_zero1(self):
test_config = {
"mp": 2,
"gpus": 4,
@ -91,16 +91,13 @@ class GPT2FuncTestCase(BaseTestCase):
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs8.json",
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
succ = self.run_partition_activations_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp4_gpu4_node1(self):
def test_mp4_gpu4_node1_zero1(self):
test_config = {
"mp": 4,
"gpus": 4,
@ -112,7 +109,82 @@ class GPT2FuncTestCase(BaseTestCase):
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs8.json",
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp1_gpu1_node1_zero2(self):
test_config = {
"mp": 1,
"gpus": 1,
"nodes": 1,
"bs": 4,
"steps": 1000,
"layers": 12,
"hidden_size": 768,
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs4_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp1_gpu2_node1_zero2(self):
test_config = {
"mp": 1,
"gpus": 2,
"nodes": 1,
"bs": 8,
"steps": 1000,
"layers": 12,
"hidden_size": 768,
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu4_node1_zero2(self):
test_config = {
"mp": 2,
"gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1000,
"layers": 12,
"hidden_size": 768,
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
succ = self.run_partition_activations_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp4_gpu4_node1_zero2(self):
test_config = {
"mp": 4,
"gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1000,
"layers": 12,
"hidden_size": 768,
"seq_length": 256,
"heads": 12,
"deepspeed": False,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
@ -144,11 +216,12 @@ class GPT2FuncTestCase(BaseTestCase):
print("\n")
print("{0}: starting......".format(self.id()))
baseline_prefix = "gpt2_func_"
prefix = "gpt2_partition_activation_"
# baseline run...
test_config["deepspeed"] = False
base_file = self.gen_output_name(test_config, prefix)
base_file = self.gen_output_name(test_config, baseline_prefix)
# skip baseline run if it exists.
if not self.has_loss_data(base_file):
@ -159,7 +232,7 @@ class GPT2FuncTestCase(BaseTestCase):
# DeepSpeed run...
test_config["deepspeed"] = True
test_config["other_args"] = "--partition-activations"
test_config["other_args"] = "--deepspeed-activation-checkpointing"
print("{0}: DeepSpeed run.".format(self.id()))
test_file = self.gen_output_name(test_config, prefix)
self.run_gpt2_test(test_config, test_file)
@ -217,10 +290,16 @@ class GPT2FuncTestCase(BaseTestCase):
def suite():
suite = unittest.TestSuite()
suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1'))
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1'))
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1'))
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1'))
suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero1'))
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero1'))
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero1'))
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero1'))
suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2'))
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2'))
suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2'))
suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2'))
suite.addTest(GPT2FuncTestCase('test_optimizer_scheduler'))
return suite

View file

@ -29,14 +29,16 @@ def pytest_hack(runner_result):
assert runner_result.wasSuccessful() # fail the test
def test_run():
#def test_megatron():
# runner = unittest.TextTestRunner(failfast=True)
# pytest_hack(runner.run(Megatron_GPT2.suite()))
#
#
#def test_megatron_checkpoint():
# runner = unittest.TextTestRunner(failfast=True)
# pytest_hack(runner.run(Megatron_GPT2.checkpoint_suite()))
def test_squad():
runner = unittest.TextTestRunner(failfast=True)
# Add test suites here.
pytest_hack(runner.run(Megatron_GPT2.suite()))
pytest_hack(runner.run(Megatron_GPT2.checkpoint_suite()))
pytest_hack(runner.run(BingBertSquad.suite()))
if __name__ == '__main__':
test_run()

View file

@ -0,0 +1,115 @@
import os
import json
import argparse
import torch
import deepspeed
from torch.utils.data.distributed import DistributedSampler
class SimpleModel(torch.nn.Module):
def __init__(self, hidden_dim, empty_grad=False):
super(SimpleModel, self).__init__()
self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
if empty_grad:
self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
def forward(self, x, y):
# 'hidden' is an activation tensor, not a dimension
hidden = self.linear(x)
return self.cross_entropy_loss(hidden, y)
def create_config_from_dict(tmpdir, config_dict):
config_path = os.path.join(tmpdir, 'temp_config.json')
with open(config_path, 'w') as fd:
json.dump(config_dict, fd)
return config_path
def get_data_loader(model, total_samples, hidden_dim, device):
batch_size = model.train_micro_batch_size_per_gpu()
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
train_label = torch.empty(total_samples,
dtype=torch.long,
device=device).random_(hidden_dim)
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
sampler = DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
sampler=sampler)
return train_loader
def get_args(tmpdir, config_dict):
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument('--zero', type=int, default=0)
args = parser.parse_args()
config_dict["zero_optimization"]["stage"] = args.zero
print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
config_path = create_config_from_dict(tmpdir, config_dict)
args.deepspeed_config = config_path
return args
def print0(msg):
if torch.distributed.get_rank() == 0:
print(msg, flush=True)
rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)
config_dict = {
"train_batch_size": 8,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
}
},
"fp16": {
"enabled": True,
"initial_scale_power": 15
},
"zero_optimization": {
"stage": 0,
"reduce_bucket_size": 20
}
}
# "initial_scale_power": 15
args = get_args('/tmp/', config_dict)
hidden_dim = 4
model = SimpleModel(hidden_dim, empty_grad=False)
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters(),
dist_init_required=True)
def print_params(tag, model):
if torch.distributed.get_rank() == 0:
for n, p in model.named_parameters():
print0("{} {}:{}".format(tag, n, p))
data_loader = get_data_loader(model=model,
total_samples=1000,
hidden_dim=hidden_dim,
device=model.device)
#print_params('pre-train', model)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
if torch.distributed.get_rank() == 0:
print("LOSS:", loss.item())
model.backward(loss)
model.step()
#print_params('step={}'.format(n), model)
if n == 5: break

View file

@ -8,7 +8,7 @@ from torch.multiprocessing import Process
import pytest
# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 10
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
def distributed_test(world_size=2, backend='nccl'):

View file

@ -1,6 +1,7 @@
import torch
import deepspeed
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
@ -9,6 +10,7 @@ import argparse
import pytest
import json
import os
import numbers
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict
@ -22,21 +24,6 @@ def compare_deepspeed_states(saved_model, loaded_model):
assert saved_model.global_steps == loaded_model.global_steps
def compare_lr_scheduler_states(saved_model, loaded_model):
if saved_model.lr_scheduler is None:
assert loaded_model.lr_scheduler is None
return
saved = saved_model.lr_scheduler.state_dict()
loaded = loaded_model.lr_scheduler.state_dict()
assert sorted(saved.keys()) == sorted(loaded.keys())
for key in saved.keys():
if isinstance(saved[key], torch.Tensor):
assert torch.equal(saved[key], loaded[key])
else:
assert saved[key] == loaded[key]
def compare_model_states(saved_model, loaded_model):
compare_deepspeed_states(saved_model, loaded_model)
@ -47,6 +34,11 @@ def compare_model_states(saved_model, loaded_model):
for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
for p0, p1 in zip(partition0, partition1):
assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_Optimizer):
for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
@ -61,10 +53,6 @@ def compare_model_states(saved_model, loaded_model):
def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
compare_model_states(saved_model, loaded_model)
assert hasattr(loaded_model, 'optimizer')
for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
loaded_model.optimizer.optimizer.state.values()):
for s0, s1 in zip(state0.values(), state1.values()):
@ -74,11 +62,35 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
assert s0 == s1
def checkpoint_correctness_verification(save_folder,
args,
def compare_lr_scheduler_states(saved_model, loaded_model):
assert hasattr(saved_model, 'lr_scheduler')
assert hasattr(loaded_model, 'lr_scheduler')
saved_scheduler = saved_model.lr_scheduler
loaded_scheduler = loaded_model.lr_scheduler
assert hasattr(saved_scheduler, 'state_dict')
assert hasattr(loaded_scheduler, 'state_dict')
saved_sd = saved_scheduler.state_dict()
loaded_sd = loaded_scheduler.state_dict()
print(f"saved_sd = {saved_sd}")
print(f"loaded_sd = {loaded_sd}")
assert saved_sd.keys() == loaded_sd.keys()
for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
assert state0 == state1
def checkpoint_correctness_verification(args,
model,
hidden_dim,
load_optimizer_states=True):
tmpdir,
load_optimizer_states=False,
load_lr_scheduler_states=False):
ds_model, _, _,_ = deepspeed.initialize(args=args,
model=model,
@ -94,6 +106,7 @@ def checkpoint_correctness_verification(save_folder,
trained_model = ds_model
save_folder = os.path.join(tmpdir, 'saved_checkpoint')
save_tag = '1'
trained_model.save_checkpoint(save_folder, save_tag)
@ -104,14 +117,16 @@ def checkpoint_correctness_verification(save_folder,
loaded_model.load_checkpoint(save_folder,
save_tag,
load_optimizer_states=load_optimizer_states)
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
compare_lr_scheduler_states(trained_model, loaded_model)
compare_model_states(trained_model, loaded_model)
if load_optimizer_states:
compare_optimizer_states(trained_model, loaded_model, hidden_dim)
else:
compare_model_states(trained_model, loaded_model)
if load_lr_scheduler_states:
compare_lr_scheduler_states(trained_model, loaded_model)
def test_checkpoint_unfused_optimizer(tmpdir):
@ -156,10 +171,10 @@ def test_checkpoint_unfused_optimizer(tmpdir):
model,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(tmpdir,
args,
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_unfused_optimizer(args=args,
@ -198,10 +213,10 @@ def test_checkpoint_fused_optimizer(tmpdir):
@distributed_test(world_size=[2])
def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
checkpoint_correctness_verification(tmpdir,
args,
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_fused_optimizer(args=args,
@ -214,7 +229,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
load_optimizer_states=False)
def test_checkpoint_zero_optimizer(tmpdir):
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
@ -231,7 +247,9 @@ def test_checkpoint_zero_optimizer(tmpdir):
"fp16": {
"enabled": True
},
"zero_optimization": True
"zero_optimization": {
"stage": zero_stage
},
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
@ -240,17 +258,165 @@ def test_checkpoint_zero_optimizer(tmpdir):
@distributed_test(world_size=[2])
def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
checkpoint_correctness_verification(tmpdir,
args,
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_zero_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=True)
_test_checkpoint_zero_optimizer(args=args,
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"betas": [0.8,
0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[2])
def _test_checkpoint_zero_no_optimizer(args,
model,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_zero_no_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False)
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"betas": [0.8,
0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[2])
def _test_checkpoint_lr_scheduler(args,
model,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
_test_checkpoint_lr_scheduler(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=True)
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[2])
def _test_checkpoint_no_lr_scheduler(args,
model,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
_test_checkpoint_no_lr_scheduler(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=False)

View file

@ -17,3 +17,19 @@ def test_only_required_fields(tmpdir):
assert run_cfg.train_batch_size == 64
assert run_cfg.train_micro_batch_size_per_gpu == 64
assert run_cfg.gradient_accumulation_steps == 1
def test_config_duplicate_key(tmpdir):
config_dict = '''
{
"train_batch_size": 24,
"train_batch_size": 24,
}
'''
config_path = os.path.join(tmpdir, 'temp_config.json')
with open(config_path, 'w') as jf:
jf.write("%s" % config_dict)
with pytest.raises(ValueError):
run_cfg = ds_config.DeepSpeedConfig(config_path)

View file

@ -144,7 +144,8 @@ def test_adamw_fp16_empty_grad(tmpdir):
_test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
def test_adam_fp16_onecycle_compatibility(tmpdir):
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
@ -171,59 +172,11 @@ def test_adam_fp16_onecycle_compatibility(tmpdir):
"fp16": {
"enabled": True
},
"zero_optimization": False
"zero_optimization": {
"stage": zero_stage
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=[1])
def _test_adam_fp16_onecycle_compatibility(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_adam_fp16_onecycle_compatibility(args=args, model=model, hidden_dim=hidden_dim)
def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"scheduler": {
"type": "OneCycle",
"params": {
"cycle_first_step_size": 16000,
"cycle_first_stair_count": 8000,
"decay_step_size": 16000,
"cycle_min_lr": 1e-06,
"cycle_max_lr": 3e-05,
"decay_lr_rate": 1e-07,
"cycle_min_mom": 0.85,
"cycle_max_mom": 0.99,
"decay_mom_rate": 0.0
}
},
"fp16": {
"enabled": True
},
"zero_optimization": True
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
@ -248,7 +201,53 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
hidden_dim=hidden_dim)
def test_zero_static_scale(tmpdir):
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_zero_static_scale(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"fp16": {
"enabled": True,
"loss_scale": 138.
},
"zero_optimization": {
"stage": zero_stage
}
}
args = args_from_dict(tmpdir, config_dict)
@distributed_test(world_size=2)
def _test_zero_static_scale(args):
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
model, optim, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
# Ensure the static scaler is configured.
assert optim.dynamic_loss_scale == False
assert optim.loss_scaler.loss_scale == 138.
# Now make sure things work..
data_loader = random_dataloader(model=model,
total_samples=10,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_zero_static_scale(args)
def test_zero_static_scale_deprecated_format(tmpdir):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
@ -291,14 +290,17 @@ def test_zero_static_scale(tmpdir):
_test_zero_static_scale(args)
def test_zero_allow_untested_optimizer(tmpdir):
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
"fp16": {
"enabled": True,
},
"zero_optimization": True,
"zero_optimization": {
"stage": zero_stage
},
"zero_allow_untested_optimizer": False
}
args = args_from_dict(tmpdir, config_dict)
@ -317,31 +319,34 @@ def test_zero_allow_untested_optimizer(tmpdir):
_test_zero_allow_untested_optimizer(args)
def test_zero_empty_partition(tmpdir):
config_dict = {
"train_batch_size": 3,
"fp16": {
"enabled": True
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"zero_optimization": True
}
args = args_from_dict(tmpdir, config_dict)
# @pytest.mark.parametrize("zero_stage", [1])
# def test_zero_empty_partition(tmpdir, zero_stage):
# config_dict = {
# "train_batch_size": 3,
# "fp16": {
# "enabled": True
# },
# "optimizer": {
# "type": "Adam",
# "params": {
# "lr": 0.00015
# }
# },
# "zero_optimization": {
# "stage": zero_stage
# }
# }
# args = args_from_dict(tmpdir, config_dict)
@distributed_test(world_size=[3])
def _test_zero_empty_partition(args):
hidden_dim = 1
model = SimpleModel(hidden_dim)
# Ensure model has 2 parameters, to cause empty partition with DP=3
assert len(list(model.parameters())) == 2
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
model.step()
# @distributed_test(world_size=[3])
# def _test_zero_empty_partition(args):
# hidden_dim = 1
# model = SimpleModel(hidden_dim)
# # Ensure model has 2 parameters, to cause empty partition with DP=3
# assert len(list(model.parameters())) == 2
# model, _, _, _ = deepspeed.initialize(args=args,
# model=model,
# model_parameters=model.parameters())
# model.step()
_test_zero_empty_partition(args)
# _test_zero_empty_partition(args)

View file

@ -73,7 +73,7 @@ def test_two_output_model(tmpdir):
summed_loss = sum(loss_tuple)
scaled_loss = model.backward(summed_loss)
expected_scaled_loss = summed_loss / gradient_accumulation_steps
expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
assert scaled_loss.item() == approx(expected_scaled_loss.item())
model.step()
@ -131,7 +131,7 @@ def test_three_output_model(tmpdir):
summed_loss = sum(loss_tuple)
scaled_loss = model.backward(summed_loss)
expected_scaled_loss = summed_loss / gradient_accumulation_steps
expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
assert scaled_loss.item() == approx(expected_scaled_loss.item())
model.step()