[DeepSpeed docs] new information (#9610)
* how to specify a specific gpu
* new paper
* expand on buffer sizes
* style
* where to find config examples
* specific example
* small updates
Parent: 1fbaa3c117
Commit: 7c07a47dfb

@@ -290,6 +290,16 @@ full support for:

1. Optimizer State Partitioning (ZeRO stage 1)
2. Gradient Partitioning (ZeRO stage 2)
3. Custom fp16 handling
4. A range of fast CUDA-extension-based optimizers
5. ZeRO-Offload

ZeRO-Offload has its own dedicated paper: `ZeRO-Offload: Democratizing Billion-Scale Model Training
<https://arxiv.org/abs/2101.06840>`__.

DeepSpeed is currently used only for training, as none of the currently available features are of use to inference.

Installation
=======================================================================================================================

@@ -329,6 +339,11 @@ Unlike ``torch.distributed.launch`` where you have to specify how many GPUs to

full details on how to configure various nodes and GPUs can be found `here
<https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node>`__.
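
For a multi-node run, the ``deepspeed`` launcher can be pointed at a hostfile that lists each node and its GPU count.
A minimal sketch, assuming two 4-GPU nodes (the hostnames and the ``ds_config.json`` filename are illustrative):

.. code-block:: bash

    # hypothetical hostfile: one "<hostname> slots=<num_gpus>" entry per node
    cat <<'EOF' > hostfile
    worker-1 slots=4
    worker-2 slots=4
    EOF
    deepspeed --hostfile=hostfile ./finetune_trainer.py --deepspeed ds_config.json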

In fact, you can continue using ``-m torch.distributed.launch`` with DeepSpeed as long as you don't need to use
``deepspeed`` launcher-specific arguments. Typically, if you don't need a multi-node setup you're not required to use
the ``deepspeed`` launcher. But since the DeepSpeed documentation uses it everywhere, we will use it here as well for
consistency.
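
For reference, a minimal sketch of launching a script via ``torch.distributed.launch`` instead (the GPU count and the
``ds_config.json`` filename are illustrative):

.. code-block:: bash

    # launch with the stock PyTorch launcher on 2 GPUs; DeepSpeed is enabled
    # purely through the Trainer's --deepspeed argument
    python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --deepspeed ds_config.json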

Here is an example of running ``finetune_trainer.py`` under DeepSpeed deploying all available GPUs:

.. code-block:: bash

@@ -402,12 +417,42 @@ find more details in the discussion below.

For a practical usage example of this type of deployment, please see this `post
<https://github.com/huggingface/transformers/issues/8771#issuecomment-759176685>`__.

Notes:

- if you need to run on a specific GPU other than GPU 0, you can't use ``CUDA_VISIBLE_DEVICES`` to limit the visible
  scope of available GPUs. Instead, you have to use the following syntax:

  .. code-block:: bash

      deepspeed --include localhost:1 ./finetune_trainer.py

  In this example, we tell DeepSpeed to use GPU 1.
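
  Similarly, ``--include`` takes a comma-separated list of GPU indices, so a run can be pinned to a specific subset of
  GPUs. A minimal sketch, assuming you want GPUs 1 and 3 of a single node:

  .. code-block:: bash

      deepspeed --include localhost:1,3 ./finetune_trainer.py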

Configuration
=======================================================================================================================

For the complete guide to the DeepSpeed configuration options that can be used in its configuration file, please refer
to the `following documentation <https://www.deepspeed.ai/docs/config-json/>`__.

You can find dozens of DeepSpeed configuration examples that address various practical needs in `the DeepSpeedExamples
repo <https://github.com/microsoft/DeepSpeedExamples>`__:

.. code-block:: bash

    git clone https://github.com/microsoft/DeepSpeedExamples
    cd DeepSpeedExamples
    find . -name '*json'

Continuing from the code above, let's say you're looking to configure the Lamb optimizer. You can search through the
example ``.json`` files with:

.. code-block:: bash

    grep -i Lamb $(find . -name '*json')

Some more examples can be found in the `main repo <https://github.com/microsoft/DeepSpeed>`__ as well.
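
If you want to search there too, a similar sketch:

.. code-block:: bash

    git clone https://github.com/microsoft/DeepSpeed
    grep -i -r Lamb DeepSpeed/ --include='*.json'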

While you always have to supply the DeepSpeed configuration file, you can configure the DeepSpeed integration in
several ways:

@@ -547,7 +592,11 @@ Notes:

- ``"overlap_comm": true`` trades off increased GPU RAM usage to lower all-reduce latency. ``overlap_comm`` uses 4.5x
  the ``allgather_bucket_size`` and ``reduce_bucket_size`` values. So if they are set to 5e8, this requires a 9GB
  footprint (``5e8 x 2Bytes x 2 x 4.5``). Therefore, if you have a GPU with 8GB or less RAM, to avoid getting
  OOM-errors you will need to reduce those parameters to about ``2e8``, which would require 3.6GB. You will want to do
  the same on larger-capacity GPUs as well if you're starting to hit OOM.
- when reducing these buffers you're trading communication speed for more available GPU RAM. The smaller the buffer
  size, the slower the communication, and the more GPU RAM will be available to other tasks. So if a bigger batch size
  is important, getting a slightly slower training time could be a good trade.

This section has to be configured exclusively via the DeepSpeed configuration file - the
:class:`~transformers.Trainer` provides no equivalent command line arguments.
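
For example, a minimal sketch of the relevant ``zero_optimization`` section with the reduced ``2e8`` buffers discussed
in the notes above, written out as a config file (the values are illustrative, and a real file would carry your other
settings too):

.. code-block:: bash

    # write a pared-down ZeRO stage 2 config using 2e8 buckets (~3.6GB footprint
    # with overlap_comm) instead of 5e8 (~9GB)
    cat <<'EOF' > ds_config.json
    {
      "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "allgather_bucket_size": 2e8,
        "reduce_bucket_size": 2e8
      }
    }
    EOF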

@@ -717,6 +766,11 @@ Main DeepSpeed Resources

- `API docs <https://deepspeed.readthedocs.io/en/latest/index.html>`__
- `Blog posts <https://www.microsoft.com/en-us/research/search/?q=deepspeed>`__

Papers:

- `ZeRO: Memory Optimizations Toward Training Trillion Parameter Models <https://arxiv.org/abs/1910.02054>`__
- `ZeRO-Offload: Democratizing Billion-Scale Model Training <https://arxiv.org/abs/2101.06840>`__

Finally, please remember that the HuggingFace :class:`~transformers.Trainer` only integrates DeepSpeed, therefore if
you have any problems or questions with regard to DeepSpeed usage, please file an issue with `DeepSpeed GitHub
<https://github.com/microsoft/DeepSpeed/issues>`__.