Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Jeff Rasley 2020-09-10 02:04:17 -07:00 committed by GitHub
Parent c76769c4ff
Commit a8a8b3d288
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
6 changed files: 129 additions and 67 deletions

@@ -1 +1 @@
Subproject commit 896831c96266e12612c3e7a923d04e68d1f4dd84
Subproject commit 9e2c34e31cec99f7d5785c6a1a3b0854c322f883


@@ -11,9 +11,13 @@ library that makes distributed training easy, efficient, and effective.
<p align="center"><i><b>10x Faster Training</b></i></p>
<p align="center"><i><b>Minimal Code Change</b></i></p>
DeepSpeed can train deep learning models with over a hundred billion parameters on current
generation of GPU clusters, while achieving over 10x in system performance
compared to the state-of-art. Early adopters of DeepSpeed have already produced
DeepSpeed delivers extreme-scale model training for everyone, from data scientists training on massive supercomputers to those training on low-end clusters or even on a single GPU:
* Extreme scale: Using the current generation of GPU clusters with hundreds of devices, DeepSpeed's 3D parallelism can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, DeepSpeed's ZeRO-Offload can train models with over 10B parameters, 10x bigger than the state of the art, democratizing multi-billion-parameter model training so that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: DeepSpeed's sparse attention powers input sequences an order of magnitude longer and obtains up to 6x faster execution compared with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency, allowing users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving convergence efficiency similar to Adam, enabling scaling to different types of GPU clusters and networks.
Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
@@ -27,7 +31,11 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
# News
* [2020/09/10] [DeepSpeed: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/project/ai-at-scale/)
* [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
* [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/09/pipeline-parallelism.html)
* [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
* [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html)
* [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand <span style="color:dodgerblue">**[_NEW_]**</span>
* [2020/07/24] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) on August 6th, 2020 <span style="color:dodgerblue">**[_NEW_]**</span>
[![DeepSpeed webinar](docs/assets/images/webinar-aug2020.png)](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-Live.html)
@@ -68,10 +76,27 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* [Model Parallelism](https://www.deepspeed.ai/features/#model-parallelism)
* Support for Custom Model Parallelism
* Integration with Megatron-LM
* [Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#memory-and-bandwidth-optimizations)
* The Zero Redundancy Optimizer (ZeRO)
* Constant Buffer Optimization (CBO)
* [Pipeline Parallelism](https://www.deepspeed.ai/tutorials/pipeline/)
* 3D Parallelism
* [The Zero Redundancy Optimizer (ZeRO)](https://www.deepspeed.ai/tutorials/zero/)
* Optimizer State and Gradient Partitioning
* Activation Partitioning
* Constant Buffer Optimization
* Contiguous Memory Optimization
* [ZeRO-Offload](https://www.deepspeed.ai/tutorials/zero-offload/)
* Leverage both CPU/GPU memory for model training
* Support 10B model training on a single GPU
* [Ultra-fast dense transformer kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
* [Sparse attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html)
* Memory- and compute-efficient sparse kernels
* Support 10x longer sequences than dense attention
* Flexible support for different sparse structures
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
* Custom communication collective
* Up to 5x communication volume saving
* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
* Smart Gradient Accumulation
* Communication/Computation Overlap
* [Training Features](https://www.deepspeed.ai/features/#training-features)
* Simplified training API
* Gradient Clipping
@@ -81,6 +106,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* Memory bandwidth optimized FP16 Optimizer
* Large Batch Training with LAMB Optimizer
* Memory efficient Training with ZeRO Optimizer
* CPU-Adam
* [Training Agnostic Checkpointing](https://www.deepspeed.ai/features/#training-agnostic-checkpointing)
* [Advanced Parameter Search](https://www.deepspeed.ai/features/#advanced-parameter-search)
* Learning Rate Range Test


@@ -74,3 +74,5 @@ analytics:
timezone: America/Los_Angeles
breadcrumbs: true
press_release_v3: https://www.microsoft.com/en-us/research/project/ai-at-scale/


@@ -0,0 +1,20 @@
---
layout: single
title: "Training a Trillion Parameters with Pipeline Parallelism"
excerpt: ""
categories: news
new_post: true
date: 2020-09-09 00:00:00
---
DeepSpeed includes new support for pipeline parallelism! DeepSpeed's training
engine provides hybrid 3D parallelism for training models with over a
trillion parameters. In addition to scaling to the extreme, we have
demonstrated that hybrid parallelism accelerates training on clusters with
low-bandwidth networks by up to 7x.
* For a brief overview and results including trillion-parameter capabilities,
see our [press release]({{ site.press_release_v3 }}).
* To get started with pipeline parallel training in DeepSpeed, we recommend our [tutorial](/tutorials/pipeline/).
* See our AlexNet example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples).
* Read our API documentation on [readthedocs](https://deepspeed.readthedocs.io/en/latest/pipeline.html).
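
As a quick taste before the tutorial, the sketch below shows the basic shape of pipeline-parallel training, loosely modeled on the AlexNet example. It is a sketch under stated assumptions: `trainset` stands in for a user-provided dataset, the stage count is illustrative, and the batch size and optimizer come from the usual DeepSpeed JSON config passed to the launcher.

```python
import argparse
import deepspeed
import torch.nn as nn
from deepspeed.pipe import PipelineModule
from torchvision.models import AlexNet

# Standard DeepSpeed launcher/config arguments (--deepspeed, --deepspeed_config, ...).
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
cmd_args = deepspeed.add_config_arguments(parser).parse_args()

# Express the model as a flat sequence of layers so DeepSpeed can partition it into stages.
net = AlexNet(num_classes=10)
layers = [*net.features, net.avgpool, nn.Flatten(), *net.classifier]
pipe = PipelineModule(layers=layers, loss_fn=nn.CrossEntropyLoss(), num_stages=2)

# The optimizer and batch sizes are defined in the JSON config supplied via --deepspeed_config.
engine, _, _, _ = deepspeed.initialize(args=cmd_args,
                                       model=pipe,
                                       model_parameters=pipe.parameters(),
                                       training_data=trainset)  # trainset: user-provided Dataset

for step in range(1000):
    loss = engine.train_batch()  # one forward/backward/step over a full pipeline schedule
```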

docs/assets/images/pp-lowbw-gpt2.png (new binary file, 50 KiB; not shown)


@@ -10,9 +10,14 @@ efficient, and effective.
<p align="center"><i><b>10x Larger Models</b></i></p>
<p align="center"><i><b>10x Faster Training</b></i></p>
<p align="center"><i><b>Minimal Code Change</b></i></p>
DeepSpeed can train DL models with over a hundred billion parameters on current
generation of GPU clusters, while achieving over 10x in system performance
compared to the state-of-art. Early adopters of DeepSpeed have already produced
DeepSpeed delivers extreme-scale model training for everyone, from data scientists training on massive supercomputers to those training on low-end clusters or even on a single GPU:
* Extreme scale: Using the current generation of GPU clusters with hundreds of devices, DeepSpeed's 3D parallelism can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, DeepSpeed's ZeRO-Offload can train models with over 10B parameters, 10x bigger than the state of the art, democratizing multi-billion-parameter model training so that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: DeepSpeed's sparse attention powers input sequences an order of magnitude longer and obtains up to 6x faster execution compared with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency, allowing users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving convergence efficiency similar to Adam, enabling scaling to different types of GPU clusters and networks.
Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
[Turing-NLG](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft),
establishing a new SOTA in the LM category.
@@ -23,19 +28,11 @@ initiative to enable next-generation AI capabilities at scale, where you can fin
information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale).
# What's New?
{% assign news = site.posts | where: "sneak_preview", "false" %}
{% for post in news limit:5 %}
{% if post.link %}
{% if post.image %}
* [{{ post.date | date: "%Y/%m/%d" }}] [ {{ post.title }} {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %} ![]({{ post.image }}) ]({{ post.link }})
{% else %}
* [{{ post.date | date: "%Y/%m/%d" }}] [{{ post.title }}]({{ post.link }}) {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %}
{% endif %}
{% else %}
* [{{ post.date | date: "%Y/%m/%d"}}] [{{ post.title }}]({{ post.url }}) {% if post.new_post %} <span style="color:dodgerblue">**NEW!**</span> {% endif %}
{% endif %}
{% endfor %}
* [2020/09/10] [DeepSpeed: Extreme-scale model training for everyone]({{ site.press_release_v3 }})
* [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
* [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/09/pipeline-parallelism.html)
* [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
* [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html)
# Why DeepSpeed?
Training advanced deep learning models is challenging. Beyond model design,
@@ -87,7 +84,7 @@ optimizations on advanced hyperparameter tuning and optimizers. For example:
## Memory efficiency
DeepSpeed provides memory-efficient data parallelism and enables training models without
model parallelism. For example, DeepSpeed can train models with up to 13 billion parameters on
NVIDIA V100 GPUs with 32GB of device memory. In comparison, existing frameworks (e.g.,
a single GPU. In comparison, existing frameworks (e.g.,
PyTorch's Distributed Data Parallel) run out of memory with 1.4 billion parameter models.
DeepSpeed reduces the training memory footprint through a novel solution called Zero
@@ -97,7 +94,7 @@ significant memory. Furthermore, it also reduces activation memory and fragmente
The current implementation (ZeRO-2) reduces memory by up to
8x relative to the state of the art. You can read more about ZeRO in our [paper](https://arxiv.org/abs/1910.02054), and
in our blog posts related to
[ZeRO-1](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/). <!-- and [ZeRO-2](linklink). -->
[ZeRO-1](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/) and [ZeRO-2](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/).
With this impressive memory reduction, early adopters of DeepSpeed have already
produced a language model (LM) with over 17B parameters called
@@ -105,15 +102,15 @@ produced a language model (LM) with over 17B parameters called
<span style="color:dodgerblue">Turing-NLG</span></a>,
establishing a new SOTA in the LM category.
For model scientists with limited GPU resources, ZeRO-Offload leverages both CPU and GPU memory for training large models. Using a machine with **a single GPU**, our users can run **models of up to 13 billion parameters** without running out of memory, 10x bigger than existing approaches, while obtaining competitive throughput. This feature democratizes multi-billion-parameter model training and opens the door for many deep learning practitioners to explore bigger and better models.
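
Enabling ZeRO-Offload is a configuration change rather than a code change. Below is a minimal sketch, assuming the `cpu_offload` flag described in the ZeRO-Offload tutorial; the batch size and fp16 settings are illustrative, and flag names may differ across releases.

```python
import json

# Minimal DeepSpeed configuration sketch for ZeRO-Offload (illustrative values).
# ZeRO stage 2 partitions optimizer states and gradients; cpu_offload moves the
# optimizer states and optimizer computation to CPU memory, freeing GPU memory
# for larger models.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": True
    }
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

# Launch as usual, e.g.: deepspeed train.py --deepspeed --deepspeed_config ds_config.json
```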
## Scalability
DeepSpeed supports efficient data parallelism, model parallelism, and their
combination. ZeRO boosts the scaling capability and efficiency further.
* <span style="color:dodgerblue">DeepSpeed provides system support to run models up to 170 billion parameters,
10x larger than the state-of-art (8 billion NVIDIA GPT, 11 billion Google T5).</span>
DeepSpeed supports efficient data parallelism, model parallelism, pipeline parallelism, and their
combinations, which we call 3D parallelism.
* <span style="color:dodgerblue">3D parallelism of DeepSpeed provides system support to run models with trillions of parameters; read more in our [press release]({{ site.press_release_v3 }}) and [tutorial](/tutorials/pipeline).</span>
* <span style="color:dodgerblue">DeepSpeed can run large models more efficiently, up to 10x
faster for models with
various sizes spanning 1.5B to 170B.</span> More specifically, the data parallelism powered by ZeRO
various sizes spanning 1.5B to hundreds of billions of parameters.</span> More specifically, the data parallelism powered by ZeRO
is complementary and can be combined with different types of model parallelism. It allows
DeepSpeed to fit models using a lower degree of model parallelism and a higher batch size, offering
significant performance gains compared to using model parallelism alone.
@@ -126,6 +123,15 @@ combination. ZeRO boosts the scaling capability and efficiency further.
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of NVIDIA Megatron-LM) over using Megatron-LM alone.</em>
</p>
## Communication efficiency
DeepSpeed's pipeline parallelism reduces communication volume during distributed training, which allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth.
![PP Figure](/assets/images/pp-lowbw-gpt2.png)
1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. [Read more here](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html).
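
1-bit Adam is likewise selected through the optimizer section of the DeepSpeed config. Below is a rough sketch with option names taken from the 1-bit Adam tutorial and illustrative values; `freeze_step` sets the number of full-precision warm-up steps before 1-bit compressed communication kicks in.

```python
# Sketch of the optimizer section of a DeepSpeed config selecting 1-bit Adam
# (illustrative values; passed to DeepSpeed through the usual JSON config file).
onebit_adam_config = {
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 4e-4,
            "freeze_step": 23000,   # warm-up steps with uncompressed communication
            "cuda_aware": True      # use CUDA-aware MPI for communication if available
        }
    }
}
```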
## Supporting long sequence length
DeepSpeed offers sparse attention kernels—an instrumental technology to support long sequences of model inputs, whether for text, image, or sound. Compared with the classic dense Transformers, it powers **an order-of-magnitude longer input sequence** and obtains up to 6x faster execution with comparable accuracy. It also outperforms state-of-the-art sparse implementations with 1.5–3x faster execution. Furthermore, our sparse kernels support efficient execution of flexible sparse formats and empower users to innovate on their custom sparse structures. [Read more here](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html).
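
For a feel of the API, here is a minimal sketch assuming the `SparseSelfAttention` and `FixedSparsityConfig` classes described in the sparse attention tutorial; the tensor layout, block size, and fixed sparsity pattern are illustrative, and a CUDA GPU with the sparse attention op installed is required.

```python
import torch
from deepspeed.ops.sparse_attention import FixedSparsityConfig, SparseSelfAttention

# Illustrative shapes: block-sparse kernels operate on blocks of the sequence,
# so the sequence length should be a multiple of the block size.
batch, heads, seq_len, head_dim, block = 2, 4, 2048, 64, 16

sparsity = FixedSparsityConfig(num_heads=heads, block=block)
attn = SparseSelfAttention(sparsity_config=sparsity).cuda()

q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.half, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

# Drop-in replacement for dense scaled-dot-product self-attention over long sequences.
out = attn(q, k, v)  # (batch, heads, seq_len, head_dim)
```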
## Fast convergence for effectiveness
DeepSpeed supports advanced hyperparameter tuning and large batch size
@@ -142,43 +148,51 @@ Only a few lines of code changes are needed to enable a PyTorch model to use Dee
## Features
Below we provide a brief feature list, see our detailed [feature
overview](/features/) for descriptions and usage.
Below we provide a brief feature list; see our detailed [feature overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* [Distributed Training with Mixed Precision](/features/#distributed-training-with-mixed-precision)
* 16-bit mixed precision
* Single-GPU/Multi-GPU/Multi-Node
* [Model Parallelism](/features/#model-parallelism)
* Support for Custom Model Parallelism
* Integration with Megatron-LM
* [The Zero Redundancy Optimizer (ZeRO)](/features/#the-zero-redundancy-optimizer)
* Optimizer State and Gradient Partitioning
* Activation Partitioning
* Constant Buffer Optimization
* Contiguous Memory Optimization
* [ZeRO-Offload](/features/#zero-offload)
* Leverage both CPU/GPU memory for model training
* Support 10B model training on a single GPU
* [Additional Memory and Bandwidth Optimizations](/features/#additional-memory-and-bandwidth-optimizations)
* Smart Gradient Accumulation
* Communication/Computation Overlap
* [Training Features](/features/#training-features)
* Simplified training API
* Activation Checkpointing API
* Gradient Clipping
* Automatic loss scaling with mixed precision
* [Training Optimizers](/features/#training-optimizers)
* Fused Adam optimizer and arbitrary `torch.optim.Optimizer`
* CPU-Adam: High-Performance vectorized Adam
* Memory bandwidth optimized FP16 Optimizer
* Large Batch Training with LAMB Optimizer
* Memory efficient Training with ZeRO Optimizer
* [Training Agnostic Checkpointing](/features/#training-agnostic-checkpointing)
* [Advanced Parameter Search](/features/#advanced-parameter-search)
* Learning Rate Range Test
* 1Cycle Learning Rate Schedule
* [Simplified Data Loader](/features/#simplified-data-loader)
* [Performance Analysis and Debugging](/features/#performance-analysis-and-debugging)
* [Distributed Training with Mixed Precision](https://www.deepspeed.ai/features/#distributed-training-with-mixed-precision)
* 16-bit mixed precision
* Single-GPU/Multi-GPU/Multi-Node
* [Model Parallelism](https://www.deepspeed.ai/features/#model-parallelism)
* Support for Custom Model Parallelism
* Integration with Megatron-LM
* [Pipeline Parallelism](https://www.deepspeed.ai/tutorials/pipeline/)
* 3D Parallelism
* [The Zero Redundancy Optimizer (ZeRO)](https://www.deepspeed.ai/tutorials/zero/)
* Optimizer State and Gradient Partitioning
* Activation Partitioning
* Constant Buffer Optimization
* Contiguous Memory Optimization
* [ZeRO-Offload](https://www.deepspeed.ai/tutorials/zero-offload/)
* Leverage both CPU/GPU memory for model training
* Support 10B model training on a single GPU
* [Ultra-fast dense transformer kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
* [Sparse attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html)
* Memory- and compute-efficient sparse kernels
* Support 10x longer sequences than dense attention
* Flexible support for different sparse structures
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
* Custom communication collective
* Up to 5x communication volume saving
* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
* Smart Gradient Accumulation
* Communication/Computation Overlap
* [Training Features](https://www.deepspeed.ai/features/#training-features)
* Simplified training API
* Gradient Clipping
* Automatic loss scaling with mixed precision
* [Training Optimizers](https://www.deepspeed.ai/features/#training-optimizers)
* Fused Adam optimizer and arbitrary `torch.optim.Optimizer`
* Memory bandwidth optimized FP16 Optimizer
* Large Batch Training with LAMB Optimizer
* Memory efficient Training with ZeRO Optimizer
* CPU-Adam
* [Training Agnostic Checkpointing](https://www.deepspeed.ai/features/#training-agnostic-checkpointing)
* [Advanced Parameter Search](https://www.deepspeed.ai/features/#advanced-parameter-search)
* Learning Rate Range Test
* 1Cycle Learning Rate Schedule
* [Simplified Data Loader](https://www.deepspeed.ai/features/#simplified-data-loader)
* [Performance Analysis and Debugging](https://www.deepspeed.ai/features/#performance-analysis-and-debugging)
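
To make the simplified training API above concrete, here is a rough sketch of the typical integration pattern; `cmd_args`, `model`, and `trainset` are user-provided, and the loop body is illustrative.

```python
import torch
import deepspeed

# `cmd_args` comes from an argparse parser extended with deepspeed.add_config_arguments();
# `model` is an ordinary torch.nn.Module and `trainset` a torch Dataset (both user-provided).
model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args=cmd_args,
    model=model,
    model_parameters=model.parameters(),
    training_data=trainset)

loss_fn = torch.nn.CrossEntropyLoss()

for step, (inputs, labels) in enumerate(trainloader):
    inputs = inputs.to(model_engine.device)
    labels = labels.to(model_engine.device)

    loss = loss_fn(model_engine(inputs), labels)  # forward through the DeepSpeed engine
    model_engine.backward(loss)                   # handles loss scaling under fp16
    model_engine.step()                           # optimizer step, LR schedule, gradient clearing
```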
# Contributing