Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: jomayeri <deepspeed@H100-VM2.shlnn55tgwve1eacvp21ie45dg.jx.internal.cloudapp.net>
This commit is contained in:
Olatunji Ruwase 2024-09-04 11:31:31 -04:00 коммит произвёл GitHub
Родитель cfc6ed3722
Коммит 5df12a4a85
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
7 изменённых файлов: 257 добавлений и 4 удалений

Просмотреть файл

@ -76,7 +76,7 @@ repos:
name: check-torchcuda
entry: ./scripts/check-torchcuda.py
language: python
exclude: ^(.github/workflows/|scripts/check-torchcuda.py|docs/_tutorials/accelerator-abstraction-interface.md|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
exclude: ^(.github/workflows/|scripts/check-torchcuda.py|docs/_tutorials/accelerator-abstraction-interface.md|docs/_tutorials/deepnvme.md|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
- repo: local

Просмотреть файл

@ -17,7 +17,7 @@ this problem, DeepSpeed has created a suite of I/O optimizations collectively ca
DeepNVMe improves the performance and efficiency of I/O-bound DL applications by accelerating I/O operations
and reducing hardware requirements. It achieves this by leveraging storage innovations such as Non-Volatile
Memory Express (NVMe) Solid Storage Devices (SSDs) and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS). In this
Memory Express (NVMe) Solid State Drives (SSDs) and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS). In this
blog we show the benefits of DeepNVMe using microbenchmarks and an inference application. In experiments
conducted on an Azure NC96ads\_A100\_v4 VM, we observed that DeepNVMe saturates available NVMe bandwidth for
data transfers with GPU or CPU memory, achieving up to 10GB/sec reads and 5 GB/secs writes.

Просмотреть файл

@ -55,6 +55,10 @@ lnav:
url: /getting-started/
- title: 'Getting started on Azure'
url: /tutorials/azure/
- title: 'Accelerator Abstraction'
url: /tutorials/accelerator-abstraction-interface/
- title: 'Accelerator Setup Guides'
url: /tutorials/accelerator-setup-guide/
- title: 'Automatic Tensor Parallelism'
url: /tutorials/automatic-tensor-parallelism/
- title: 'Autotuning'
@ -69,6 +73,8 @@ lnav:
url: /tutorials/curriculum-learning/
- title: 'Data Efficiency'
url: /tutorials/data-efficiency/
- title: 'DeepNVMe'
url: /tutorials/deepnvme/
- title: 'DS4Sci_EvoformerAttention'
url: /tutorials/ds4sci_evoformerattention/
- title: 'Flops Profiler'

Просмотреть файл

@ -1,6 +1,6 @@
---
title: DeepSpeed Accelerator Abstraction Interface
tags: getting-started
tags: getting-started training accelerator
---
# Contents

Просмотреть файл

@ -1,6 +1,6 @@
---
title: DeepSpeed Accelerator Setup Guides
tags: getting-started
tags: getting-started training accelerator
---
# Contents

247
docs/_tutorials/deepnvme.md Normal file
Просмотреть файл

@ -0,0 +1,247 @@
---
title: "DeepNVMe"
tags: training inference IO large-model
---
This tutorial will show how to use [DeepNVMe](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md) for data transfers between persistent storage and tensors residing in host or device memory. DeepNVMe improves the performance and efficiency of I/O operations in Deep Learning applications through powerful optimizations built on Non-Volatile Memory Express (NVMe) Solid State Drives (SSDs), Linux Asynchronous I/O (`libaio`), and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS).
## Requirements
Ensure your environment is properly configured to use DeepNVMe. First, you need to install DeepSpeed version >= [0.15.0](https://github.com/microsoft/DeepSpeed/releases/tag/v0.15.0). Next, ensure that the DeepNVMe operators are available in the DeepSpeed installation. The `async_io` operator is required for any DeepNVMe functionality, while the `gds` operator is required only for GDS functionality. You can confirm availability of each operator by inspecting the output of `ds_report` to check that compatible status is <span style="color:green">[OKAY]</span>. Below is a snippet of `ds_report` output confirming the availability of both `async_io` and `gds` operators.
![deepnvme_ops_report](/assets/images/deepnvme_ops_report.png)
If `async_io` operator is unavailable, you will need to install the appropriate `libaio` library binaries for your Linux flavor. For example, Ubuntu users will need to run `apt install libaio-dev`. In general, you should carefully inspect `ds_report` output for helpful tips such as the following:
```bash
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
```
To enable `gds` operator, you will need to install NVIDIA GDS by consulting the appropriate guide for [bare-metal systems](https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/index.html) or Azure VMs (coming soon).
## Creating DeepNVMe Handles
DeepNVMe functionality can be accessed through two abstractions: `aio_handle` and `gds_handle`. The `aio_handle` is usable on both host and device tensors. while `gds_handle` works only on CUDA tensors, but is more efficient. The first step to use DeepNVMe is to create a desired handle. `aio_handle` requires `async_io` operator, while `gds_handle` requires both `async_io` and `gds` operators. The following snippets illustrate `aio_handle` and `gds_handle` creation respectively.
```python
### Create aio_handle
from deepspeed.ops.op_builder import AsyncIOBuilder
aio_handle = AsyncIOBuilder().load().aio_handle()
```
```python
### Create gds_handle
from deepspeed.ops.op_builder import GDSBuilder
gds_handle = GDSBuilder().load().gds_handle()
```
For simplicity, the above examples illustrate handle creation using default parameters. We expect that handles created with default parameters to provide good performance in most environments. However, you can see [below](#advanced-handle-creation) for advanced handle creation.
## Using DeepNVMe Handles
`aio_handle` and `gds_handle` provide identical APIs for storing tensors to files or loading tensors from files. A common feature of these APIs is that they take a tensor and a file path as arguments for the desired I/O operation. For best performance, pinned device or host tensors should be used for I/O operations (see [here](#pinned-tensors) for details). For brevity, this tutorial will use `aio_handle` for illustration, but keep in mind that `gds_handle` works similarly.
You can see the available APIs in a Python shell via tab completion on an `aio_handle` object . This is illustrated using tab completion of `h.`.
```bash
>python
Python 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.
h.async_pread( h.free_cpu_locked_tensor( h.get_overlap_events( h.get_single_submit( h.new_cpu_locked_tensor( h.pwrite( h.sync_pread( h.wait(
h.async_pwrite( h.get_block_size( h.get_queue_depth( h.get_thread_count( h.pread( h.read( h.sync_pwrite( h.write(
```
The APIs of interest for performing I/O operations are those named with `pread` and `pwrite` substrings. For brevity, we will focus on the file write APIs, namely `sync_pwrite`, `async_pwrite`, and `pwrite`. We will discuss only `sync_pwrite` and `async_pwrite` below because they are specializations of `pwrite`.
### Blocking File Write
`sync_pwrite` provides the standard blocking semantics of Python file write. The example below illustrates using `sync_pwrite` to store a 1GB CUDA tensor to a local NVMe file.
```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.sync_pwrite(t,'/local_nvme/test_1GB.pt')
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```
### Non-Blocking File Write
An important DeepNVMe optimization is the non-blocking I/O semantics which enables Python threads to overlap computations with I/O operations. `async_pwrite` provides the non-blocking semantics for file writes. The Python thread can later use `wait()` to synchronize with the I/O operation. `async_write` can also be used to submit multiple back-to-back non-blocking I/O operations, of which can then be later blocked on using a single `wait()`. The example below illustrates using `async_pwrite` to store a 1GB CUDA tensor to a local NVMe file.
```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```
<span style="color:red">Warning for non-blocking I/O operations:</span> To avoid data races and corruptions, `.wait()` must be carefully used to serialize the writing of source tensors, and the reading of destination tensors. For example, the following update of `t` during a non-blocking file write is unsafe and could corrupt `/local_nvme/test_1GB.pt`.
```bash
>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt')
>>> t += 1 # <--- Data race; avoid by preceding with `h.wait()`
```
Similar safety problems apply to reading the destination tensor of a non-blocking file read without `.wait()` synchronization.
### Parallel File Write
An important DeepNVMe optimization is the ability to parallelize individual I/O operations. This optimization is enabled by specifying the desired parallelism degree when constructing a DeepNVMe handle. Subsequent I/O operations with that handle are automatically parallelized over the requested number of host or device threads, as appropriate. I/O parallelism is composable with either the blocking or non-blocking I/O APIs. The example below illustrates 4-way parallelism of a file write using `async_pwrite`. Note the use of `num_threads` argument to specify the desired parallelism degree in handle creation.
```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle(num_threads=4)
>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```
### Pinned Tensors
A key part of DeepNVMe optimizations is using direct memory access (DMA) for I/O operations, which requires that the host or device tensor be pinned. To pin host tensors, you can use mechanisms provided by [Pytorch](https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html) or [DeepSpeed Accelerators](/tutorials/accelerator-abstraction-interface/#tensor-operations). The following example illustrates writing a pinned CPU tensor to a local NVMe file.
```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t=torch.empty(1024**3, dtype=torch.uint8).pin_memory()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```
On the other hand,`gds_handle` provides `new_pinned_device_tensor()` and `pin_device_tensor()` functions for pinning CUDA tensors. The following example illustrates writing a pinned CUDA tensor to a local NVMe file.
```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import GDSBuilder
>>> h = GDSBuilder().load().gds_handle()
>>> h.pin_device_tensor(t)
>>> h.async_pwrite(t,'/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
>>> h.unpin_device_tensor(t)
```
## Putting it together
We hope that the above material helps you to get started with DeepNVMe. You can also use the following links to see DeepNVMe usage in real-world Deep Learning applications.
1. [Parameter swapper](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py#L111-L117) in [ZeRO-Inference](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) and [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/).
2. [Optimizer swapper](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L36-L38) in [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/).
3. [Gradient swapper](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L41-L43) in [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/).
4. Simple file read and write [operations](https://github.com/microsoft/DeepSpeedExamples/blob/master/deepnvme/file_access/README.md).
<!-- 1. ZeRO-Inference: used for [parameter offloading](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py#L111-L117).
2. [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/): used for offloading [parameters](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py#L111-L117), [gradients](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L41-L43), and [optimizer](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L36-L38).
3. Simple file read and write [operations](https://github.com/microsoft/DeepSpeedExamples/blob/master/deepnvme/file_access/README.md). -->
## Acknowledgements
This tutorial has been significantly improved by feedback from [Guanhua Wang](https://github.com/GuanhuaWang), [Masahiro Tanaka](https://github.com/tohtana), and [Stas Bekman](https://github.com/stas00).
## Appendix
### Advanced Handle Creation
Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, and SSD). For convenience we make it possible to create handles using default parameter values which will provide decent performance in most scenarios. However, squeezing out every available performance in your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `num_threads`. The `aio_handle` constructor parameters and default values are illustrated below:
```bash
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> help(AsyncIOBuilder().load().aio_handle())
Help on aio_handle in module async_io object:
class aio_handle(pybind11_builtins.pybind11_object)
| Method resolution order:
| aio_handle
| pybind11_builtins.pybind11_object
| builtins.object
|
| Methods defined here:
|
| __init__(...)
| __init__(self: async_io.aio_handle, block_size: int = 1048576, queue_depth: int = 128, single_submit: bool = False, overlap_events: bool = False, num_threads: int = 1) -> None
|
| AIO handle constructor
```
### DeepNVMe APIs
For convenience, we provide listing and brief descriptions of the DeepNVMe APIs.
#### General I/O APIs
The following functions are used for I/O operations with both `aio_handle` and `gds_handle`.
Function | Description |
|---|---|
async_pread | Non-blocking file read into tensor |
sync_pread | Blocking file read into tensor |
pread | File read with blocking and non-blocking options |
async_pwrite | Non-blocking file write from tensor |
sync_pwrite | Blocking file write from tensor |
pwrite | File write with blocking and non-blocking options |
wait | Wait for non-blocking I/O operations to complete |
#### GDS-specific APIs
The following functions are available only for `gds_handle`
Function | Description
|---|---|
new_pinned_device_tensor | Allocate and pin a device tensor |
free_pinned_device_tensor | Unpin and free a device tensor |
pin_device_tensor | Pin a device tensor |
unpin_device_tensor | unpin a device tensor |
#### Handle Settings APIs
The following APIs can be used to probe handle configuration.
Function | Description
|---|---|
get_queue_depth | Return queue depth setting |
get_single_submit | Return whether single_submit is enabled |
get_thread_count | Return I/O parallelism degree |
get_block_size | Return I/O block size setting |
get_overlap_events | Return whether overlap_event is enabled |

Двоичные данные
docs/assets/images/deepnvme_ops_report.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 8.8 KiB