Commit graph

1333 commits

Author SHA1 Message Date
ChenWenbin cd20a3bbc7
Fix potential memory issues when using DeepSpeed ZeRO-3 (#6726)
I hit an OOM problem when doing DPO training with ZeRO-3. DPO needs to call
the module twice in one training step, and the second call is under no_grad().
The problem is caused by two bugs:
1. `__n_available_params`, which helps control the number of fetched parameters,
becomes negative after `release_and_reset_all()`.
2. `module.ds_grads_remaining` becomes negative in `backward()` if the module is
called more than once in one training step.

I created two patches to fix these issues.
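
For context, a minimal sketch of the call pattern described above (two forward calls per step, the second under `no_grad()`); the helper names and the toy loss are illustrative, not the actual DPO math:

```python
import torch

def dpo_style_step(engine, policy_batch, reference_batch):
    # First forward: regular training call through the DeepSpeed engine.
    policy_logits = engine(policy_batch)

    # Second forward in the same training step, under no_grad() (e.g. for the
    # reference-model log-probs). This repeated-call pattern is what exposed
    # the two counter bugs above.
    with torch.no_grad():
        reference_logits = engine(reference_batch)

    loss = policy_logits.float().mean() - reference_logits.float().mean()
    engine.backward(loss)
    engine.step()
```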

---------

Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
2024-11-21 18:32:03 +00:00
Hyeonseung Lee f515104e95
Removes unnecessary cloning (#6761)
In the `clone_tensors_for_torch_save()` function:

When `item.device` differs from the `device` argument, `tensor.clone()` is not
actually required, because `to()` also returns a copy of the original tensor.


Additionally, I observed memory bloat under the following conditions:
* Training a Whisper model with the `transformers` framework under a `ZeRO-0`
or `ZeRO-1` configuration.
* The memory bloats every time the model state_dict is cloned using
`clone_tensors_for_torch_save()`.

After removing the unnecessary `clone()`, the problem appears to be solved.
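
A simplified sketch of the logic described (not the exact DeepSpeed helper; the recursion over containers is an assumption):

```python
import torch

def clone_tensors_for_torch_save_sketch(item, device=torch.device("cpu")):
    """Prepare tensors for torch.save by moving/copying them to `device`."""
    if torch.is_tensor(item):
        if item.device == device:
            # Same device: an explicit clone is still needed to detach from
            # any shared (e.g. flat-buffer) storage.
            return item.detach().clone()
        # Different device: `to()` already produces a new tensor on `device`,
        # so an extra `clone()` only doubles the memory traffic.
        return item.detach().to(device)
    if isinstance(item, dict):
        return {k: clone_tensors_for_torch_save_sketch(v, device) for k, v in item.items()}
    if isinstance(item, (list, tuple)):
        return type(item)(clone_tensors_for_torch_save_sketch(v, device) for v in item)
    return item
```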

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-11-21 17:37:29 +00:00
Max Kovalenko b5709cce66
Enable torch compile on _allgather_params (#6769)
* Previously, ZeRO-3 crashed when trying to compile `_allgather_params`.
* Disabling grad resolves the issue.
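
A minimal sketch of the "disable grad" idea under torch.compile; the body below is illustrative, not the actual `_allgather_params` implementation:

```python
import torch
import torch.distributed as dist

@torch.no_grad()  # gathering partitioned weights is bookkeeping, not part of autograd
def allgather_params_sketch(flat_shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # Allocate the full buffer and all-gather every rank's shard into it.
    full = torch.empty(flat_shard.numel() * world_size,
                       dtype=flat_shard.dtype, device=flat_shard.device)
    dist.all_gather(list(full.chunk(world_size)), flat_shard)
    return full

# With grad disabled, the function can be compiled without crashing.
compiled_allgather = torch.compile(allgather_params_sketch)
```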
2024-11-21 16:01:13 +00:00
Quentin Gallouédec 83e4364fbd
Use `json_schema_extra` instead of extra keyword in `Field` (#6764)
> Using extra keyword arguments on `Field` is deprecated and will be
removed. Use `json_schema_extra` instead. (Extra keys: 'new_param').
Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2
Migration Guide at https://errors.pydantic.dev/2.9/migration/
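
A small before/after sketch of the deprecated pattern versus the recommended one (the field and metadata names are illustrative):

```python
from pydantic import BaseModel, Field

class ZeroConfigOld(BaseModel):
    # Deprecated in Pydantic v2: arbitrary extra keyword arguments on Field.
    stage: int = Field(0, new_param="stage3_prefetch_bucket_size")

class ZeroConfigNew(BaseModel):
    # Recommended: carry the extra metadata in json_schema_extra instead.
    stage: int = Field(0, json_schema_extra={"new_param": "stage3_prefetch_bucket_size"})
```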

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-11-20 01:04:47 +00:00
Logan Adams 2e0c39b55c
Add explicit parameters for torch.load (#6751)
Successor PR to #6094:

> FutureWarning: You are using torch.load with weights_only=False (the
current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute
arbitrary code during unpickling (See
https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models
for more details). In a future release, the default value for
weights_only will be flipped to True. This limits the functions that
could be executed during unpickling. Arbitrary objects will no longer be
allowed to be loaded via this mode unless they are explicitly
allowlisted by the user via torch.serialization.add_safe_globals. We
recommend you start setting weights_only=True for any use case where you
don't have full control of the loaded file. Please open an issue on
GitHub for any issues related to this experimental feature.

Todo:
- [ ] Update values in non-test files to True where necessary.
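
A short sketch of what the explicit parameter looks like at a call site (paths are placeholders):

```python
import torch

# Checkpoints that contain arbitrary pickled Python objects still need the old
# behavior, but it is now stated explicitly rather than relying on the default.
full_state = torch.load("checkpoint.pt", map_location="cpu", weights_only=False)

# Where only tensors and plain containers are expected, opt in to the safer mode.
weights = torch.load("model_weights.pt", map_location="cpu", weights_only=True)
```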
2024-11-19 11:09:52 -08:00
Xu Song dd40269426
A faster and more memory-efficient implementation of `zero_to_fp32` (#6658)
It is a faster and more memory-efficient implementation of `zero_to_fp32`.


The previous version doubled the memory usage, which caused CPU OOM for very
large models (e.g. Llama 405B).

b647fb2470/deepspeed/utils/zero_to_fp32.py (L438-L441)


## How does it work?

1. **Lazy loading**: Load checkpoint with `mmap=True`, thus the weights
are mmaped rather than loading all the storages into memory.
2. **Lazy merge**: `GatheredTensor` contains the mmaped weights and
tensor offset. It is a memory-efficient pseudo tensor. Only when
`tensor.contiguous()` is called, it starts to load related weights to
memory and merge into a single tensor.
3. **Release memory in time**: Save checkpoints shard by shard, and
release the memory once a shard is saved.


Throughout the process, only one shard of tensors is kept in memory.
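
A rough sketch of the lazy-loading/lazy-merge idea (the file name is a placeholder and the `GatheredTensor` internals are simplified assumptions):

```python
import torch

# mmap=True keeps the checkpoint shard on disk; storages are paged in lazily.
shard = torch.load("zero_pp_rank_0_mp_rank_00_optim_states.pt",
                   map_location="cpu", mmap=True, weights_only=False)

class GatheredTensorSketch:
    """Pseudo-tensor holding mmapped flat shards plus per-shard (offset, numel)."""
    def __init__(self, flat_shards, offsets, shape):
        self.flat_shards, self.offsets, self.shape = flat_shards, offsets, shape

    def contiguous(self) -> torch.Tensor:
        # Only here are the mmapped bytes actually read into RAM and merged.
        pieces = [s[o:o + n] for s, (o, n) in zip(self.flat_shards, self.offsets)]
        return torch.cat(pieces).reshape(self.shape)
```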

## How much benefit in speed and memory?

Experiments were conducted on a Linux host with 1 TB of memory. Here is a
detailed comparison:

|                           | world size | peak memory (GB) | elapsed time (h:mm:ss) |
|---------------------------|------------|------------------|------------------------|
| llama3-8B (old -> new)    | 8          | 90 -> 41         | 0:02:17 -> 0:01:10     |
| llama2-13B (old -> new)   | 8          | 146 -> 54        | 0:02:30 -> 0:01:47     |
| llama2-70B (old -> new)   | 16         | 789 -> 159       | 0:20:47 -> 0:20:45     |
| qwen1.5-110B (old -> new) | 32         | OOM -> 217       | ? -> 0:34:21           |
| llama3-405B (old -> new)  | 192        | OOM -> 262       | ? -> 2:09:59           |



You can reproduce these results with the following script:
```sh
# 1. install requirements
apt-get install time
# 2. prepare zero-3 checkpoints
# 3. convert zero to fp32 checkpoints
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
```

- **memory**: Theoretically, this PR reduces the memory cost from `2M` to
`(1/n)M`, where `M` is the memory cost of the full weights and `n` is the
number of shards.
- **speed**: The speed gain mainly comes from avoiding extra tensor copies.
The benefit may be slight.




## Impl history

- [v1](19712a1c75 (diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447)):
an hf_hub-compatible approach. It was discarded due to the controversial
implementation of `data_ptr()`.
- [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simpler
approach with `torch.empty`.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-11-18 20:14:35 +00:00
Olatunji Ruwase fc4e73370d
Add no_sync context manager (#6675)
Fix #1902
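
A hedged usage sketch of what such a context manager typically enables (skipping gradient reduction for intermediate micro-batches); the exact engine semantics may differ:

```python
# `engine` is a deepspeed.initialize(...) engine; `micro_batches` is a list of batches.
for i, batch in enumerate(micro_batches):
    if i < len(micro_batches) - 1:
        # Intermediate micro-batches: accumulate locally, skip gradient all-reduce.
        with engine.no_sync():
            loss = engine(batch)
            engine.backward(loss)
    else:
        # Last micro-batch: reduce gradients and take the optimizer step.
        loss = engine(batch)
        engine.backward(loss)
        engine.step()
```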

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-11-14 18:52:51 +00:00
Joe Mayer b692cdea47
AIO File Offsets (#6641)
Adding the option for a file offset to the read/write functions of AIO &
GDS ops.

---------

Co-authored-by: jomayeri <deepspeed@H100-VM2.shlnn55tgwve1eacvp21ie45dg.jx.internal.cloudapp.net>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-11-12 16:34:17 +00:00
inkcherry 7af3a4beb5
Add ZeRO-3 `module_granularity_threshold` to zero optimization (#6649)
This PR adds ZeRO-3 coalesced fetch to zero optimization. Some of the existing
logic can be reused, but it is difficult to expose it as an optimization choice
(I only discovered that logic while trying to implement this).

The benefit of this approach is reduced host overhead (fewer hooks) during the
recursive fetching of parameters, especially in fine-grained models such as
those with a large number of MoE experts. This is particularly helpful for
host-sensitive devices (such as HPU), where it achieved a 40% performance
improvement in our customer workloads.
FYI @delock @deepcharm
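
A minimal sketch of how the new option might appear in a DeepSpeed config (the key name comes from the PR title; the value and its exact placement under `zero_optimization` are assumptions, check the docs):

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        # Fine-grained sub-modules below this granularity are fetched in a
        # coalesced fashion, trading many hooks for less host overhead.
        "module_granularity_threshold": 32768,  # illustrative value
    },
}
```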

---------

Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-11-12 14:25:33 +00:00
Hongwei Chen 73d974ee64
Add data type check for bf16 (#6742)
Add data type check for bf16 to fix #6723
2024-11-12 13:01:31 +00:00
Chengming Zhang fabab197f7
Add Domino code (#6733)
add domino code

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-11-11 23:55:09 +00:00
Logan Adams d2a4718946
Update yapf version (#6721)
This update is needed to support eventually running on ubuntu-24.04 from
GitHub, specifically because the python version is updated to 3.12 and
results in the following error: `ModuleNotFoundError: No module named
'lib2to3'` since that package is deprecated.
2024-11-06 18:57:12 +00:00
Masahiro Tanaka 351569dd4a
Use one param coordinator for both train/inference scenarios (#6662)
The parameter coordinator in ZeRO3 throws a "backward pass is invalid
for module in evaluation mode" error when the training mode is
unexpected, as it expects all modules to be in training mode during the
backward pass. This is an unnecessarily strict restriction.
This PR relaxes the restriction by using a single parameter coordinator
(instead of separate ones for training and evaluation modes) and
resetting the prefetch state before starting a forward pass.
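
An illustrative sketch of the kind of mixed train/eval usage that could previously trip the coordinator (simplified; `engine` is a ZeRO-3 DeepSpeed engine and the batches are placeholders):

```python
import torch

engine.train()
loss = engine(train_batch)            # training forward

engine.eval()
with torch.no_grad():
    _ = engine(validation_batch)      # evaluation forward before the backward pass

engine.train()
engine.backward(loss)                 # formerly could raise
                                      # "backward pass is invalid for module in evaluation mode"
engine.step()
```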

Use of `is_compiling` needs to be fixed after #6663 is merged.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-11-05 22:53:01 +00:00
Xinyu Lian ff1c54351f
fix memcpy issue on backward for zero-infinity (#6670)
This PR is similar to
[PR#5301](https://github.com/microsoft/DeepSpeed/pull/5301), which optimizes
the D2H time by using pinned memory.

Previously, the D2H memcpy was the bottleneck during the final backward pass of
each iteration for ZeRO-Infinity (offload), as shown in Trace-1. The new
version eliminates the bottleneck, as shown in Trace-2.
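
For reference, a minimal sketch of the pinned-memory D2H pattern (standalone PyTorch, not the DeepSpeed code itself; requires a CUDA device):

```python
import torch

grad = torch.randn(1024, 1024, device="cuda")

# Pageable destination: the copy is synchronous and staged through a bounce buffer.
pageable_dst = torch.empty(grad.shape, dtype=grad.dtype, device="cpu")
pageable_dst.copy_(grad)

# Pinned destination: DMA can copy directly and the transfer can overlap compute.
pinned_dst = torch.empty(grad.shape, dtype=grad.dtype, device="cpu", pin_memory=True)
pinned_dst.copy_(grad, non_blocking=True)
torch.cuda.synchronize()  # ensure the async copy finished before reading pinned_dst
```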

_Trace-1_
<img width="480" alt="image"
src="https://github.com/user-attachments/assets/891e3770-351b-4e03-8a59-b491bc44d03b">

_Trace-2_
<img width="192" alt="image"
src="https://github.com/user-attachments/assets/f1cf9037-77f8-42a6-adc8-d5c6bacde0aa">

cc @tjruwase

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-10-31 10:56:09 -07:00
Yejing-Lai c7f58c899f
Add attribute check to support git-base autotp (#6688)
The git-base model is an image-text model. After adding support for the
llama3.2 vision model, we set num_kv_heads dynamically.
git-base only includes vision_config, so we need to add an attribute check for
vision_config/text_config when setting num_kv_heads.
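
A hedged sketch of the kind of attribute check described (the helper name and exact config fields are assumptions):

```python
def resolve_num_kv_heads(config, default=None):
    # Prefer text_config when present; git-base only ships a vision_config.
    for sub_name in ("text_config", "vision_config"):
        sub_cfg = getattr(config, sub_name, None)
        if sub_cfg is not None and hasattr(sub_cfg, "num_key_value_heads"):
            return sub_cfg.num_key_value_heads
    return getattr(config, "num_key_value_heads", default)
```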

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-31 00:48:52 +00:00
xuanhua e4a247ed13
Fix training of pipeline based peft's lora model (#5477)
Hi, guys

I found an assertion failure when training Hugging Face's PEFT LoRA-based
model in pipeline style.

Here are the steps I used to create my model:
1)  Load the pre-trained chatglm-6b model from Hugging Face, as Model_A
2) Use Hugging Face PEFT's `get_peft_model(...)` with my `LoraConfig(...)` on
Model_A to create the LoRA model, as Model_B
3)  Create my own pipeline-based model Model_C from Model_B

I run Model_C on two 3090 Ti GPUs, and the assertion failure looks like this:
```text
Traceback (most recent call last):
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 372, in <module>
    main()
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 351, in main
    loss = engine.train_batch(data_iter=train_dataloader)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 375, in train_batch
    self._exec_schedule(sched)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 1375, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 276, in _exec_reduce_tied_grads
    dist.all_reduce(grad, group=group)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 496, in all_reduce
    return cdb.all_reduce(tensor, op, group, async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 159, in all_reduce
    return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1520, in all_reduce
    _check_single_tensor(tensor, "tensor")
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 463, in _check_single_tensor
    raise RuntimeError(
RuntimeError: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.
```

After some debugging, I found that the root cause is that my LoRA configuration
(below) only adds extra LoRA layers to the QKV-related layers, but not to the
embedding layer. So all of the embedding layer's parameters are frozen.
```python
lora_config = LoraConfig(r=8, # copied from finetuning_lora.py
                        lora_alpha=32,
                        target_modules=["query_key_value"],
                        lora_dropout=0.1,
                        bias="none",
                        task_type="CAUSAL_LM",
                        inference_mode=False,
                        )   
```
In my implementation of the pipeline-based model, I declared the embedding
layer as a tied layer. So the situation is that there are no gradients at all
for the embedding layer, yet the embedding layer, as the tied layer, still
needs to be synced between the two GPUs. The gradient's value is None, but it
is still passed to the `all_reduce` operation.

Currently, my fix is simple: add a check for whether `grad` is None (see the
sketch below).
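
A minimal sketch of that guard around the tied-gradient reduction (simplified; the surrounding pipeline-engine code is assumed):

```python
import deepspeed.comm as dist

def reduce_tied_grads_sketch(tied_params, group):
    for p in tied_params:
        grad = p.grad
        # Frozen tied layers (e.g. a LoRA-frozen embedding) never get a gradient;
        # skip them instead of handing None to all_reduce.
        if grad is None:
            continue
        dist.all_reduce(grad, group=group)
```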

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-29 16:04:35 +00:00
Logan Adams 54903e09eb
Update profiler registration check (#6668)
Resolves #5432.
2024-10-25 22:14:26 +00:00
Masahiro Tanaka 24285d6c73
Add fallback for is_compiling (#6663)
Importing `torch.compiler.is_compiling` causes an error with an older
version of PyTorch.
This PR adds a fallback for `is_compiling` to use an equivalent function
of older PyTorch versions.
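
A hedged sketch of such a fallback (the older-PyTorch symbol location is an assumption; adjust it to whatever the supported versions actually expose):

```python
try:
    from torch.compiler import is_compiling
except ImportError:
    try:
        # Older PyTorch: an equivalent check lived under torch._dynamo (assumption).
        from torch._dynamo.external_utils import is_compiling
    except ImportError:
        def is_compiling() -> bool:
            # Very old PyTorch without dynamo support: never compiling.
            return False
```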

This will resolve #6656.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-25 20:47:22 +00:00
inkcherry 5fb71c0a18
sequence parallel for uneven heads (#6392)
In sequence_parallel (Ulysses), the sequence parallel size is constrained by
the requirement that the number of heads be divisible by it, which prevents
some models/workloads from setting a specific sequence parallel size. This PR
implements uneven all-to-all heads splitting.

- Supports both batch-first (b, s, ...) and seq_len-first (s, b, ...) layouts.
- Added unit tests with numerical checks. Locally also tested with **7 heads
with sp=4** and **20 heads with sp=8**, and it passed.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-25 18:26:47 +00:00
Yichen Yan 3d5cf739ea
Fix dynamo issue (#6527)
Dynamo uses FakeTensor to trace tensor ops. In some cases, this mechanism
breaks compiling with DeepSpeed.

An example can be found at
https://gist.github.com/oraluben/9b8240c2fe482eb4382453d6c97a5f76; to see the
issues, install deepspeed==0.14.4 instead of my fork.

Without this PR, llama cannot be compiled.

Detailed explanation:

1. `ZeROOrderedDict`
Dynamo uses deepcopy to copy tensors, which calls `object.__reduce__`. When
copying a `ZeROOrderedDict`, the default implementation does not copy its
`_parent_module`, which leads to failure.
2. `param` may be a FakeTensor that does not yet have `ds_status`; during
tracing it is fine to simply skip `register_external_parameter`, since it
should have been done well before (see the sketch below).
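
A sketch of the second fix as described above (skip registration when the ZeRO metadata is missing); the registration body itself is a simplified assumption:

```python
def register_external_parameter_sketch(module, parameter):
    # Under dynamo tracing, `parameter` can be a FakeTensor that never went
    # through ZeRO init, so it carries no ds_status; the real registration
    # already happened eagerly, hence skipping here is safe.
    if not hasattr(parameter, "ds_status"):
        return
    if not hasattr(module, "_external_params"):
        module._external_params = {}
    module._external_params[id(parameter)] = parameter
```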

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-25 00:17:30 +00:00
Lzhang-hub 6e6563d3c8
fix init_device_mesh for torch 2.4 (#6614)
Starting with torch 2.4, in
[`init_device_mesh()`](de4c2a3b4e/torch/distributed/device_mesh.py (L915)),
a device type with a GPU index, such as "cuda:0", is not allowed.
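
A small sketch of the compliant call (pass the bare device type and let the local rank mapping pick the index; assumes the process group is already initialized):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# torch >= 2.4 rejects "cuda:0" here; only the device *type* is allowed.
mesh = init_device_mesh("cuda", mesh_shape=(dist.get_world_size(),))
```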


![image](https://github.com/user-attachments/assets/1ddb61bf-8a15-4e0a-9115-a3681d7f19ff)

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com>
2024-10-23 20:29:30 +00:00
Yejing-Lai e06bb518aa
Add attribute check for language_model when replace last linear module (#6650)
Fix module has no attribute 'language_model' issue.

Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-23 20:22:59 +00:00
wyooyw b647fb2470
Fix expert grad scaling problem with ZeRO optimizer (#6546)
Fix [#6545]

Work:
- Expert gradient averaging: divide by dp_world_size instead of edp_world_size.
- Unit test: make sure models with different dp/ep sizes have the same expert
gradients.

---------

Co-authored-by: wangyiou <wangyiou@xiaohongshu.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-23 00:08:39 +00:00
Shelly Nahir ce468c3756
add option to disable logger while compiling to avoid graph breaks (#6496)
Adds an option to disable logger calls while compiling to avoid graph breaks.
Here I used an environment variable to determine whether to activate this
option, but it can also be controlled through the JSON config file or any other
way you see fit.

---------

Co-authored-by: snahir <snahir@habana.ai>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-15 18:30:42 +00:00
Xu Song bf60fc0ca6
Support safetensors export (#6579)
## Feature

This commit implements the following features:

- [x] support saving checkpoint as safetensors (more commonly used
format)
- [x] support sharding checkpoints (which is important for very large
models)

Most of the code is borrowed from
https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/modeling_utils.py#L2490

## Usage

For `pytorch_model.bin` export
```
python zero_to_fp32.py . output_dir/
```

For  `model.safetensors` export
```
python zero_to_fp32.py . output_dir/ --safe_serialization
```

---------

Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-15 11:22:31 +00:00
Joe Mayer 85b7469ea0
Add first Step in LR Schedulers (#6597)
Some (not all) of the LR schedulers in runtime were missing the
initialization of the optimizer group lr.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-14 19:31:45 +00:00
diskkid 13c16c9562
Accept btl_tcp_if_include option through launcher_args (#6613)
This patch fixes issue #4460.
When `btl_tcp_if_include` option is provided through `--launcher_args`,
we use the provided option instead of the hardcoded `--mca
btl_tcp_if_include eth0`. Otherwise we use `--mca btl_tcp_if_include
eth0` as the default for compatibility.

Fixes #4460

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-10-14 19:26:24 +00:00
Olatunji Ruwase 65ab64481f
Add API for updating ZeRO gradients (#6590) 2024-10-14 17:35:41 +00:00
Masahiro Tanaka 5c4b97f109 apply fp16 autocast only to floating point values 2024-10-11 19:41:10 +00:00
Masahiro Tanaka adec99121b
Add API to get devices of offload states (#6586)
This PR adds an API,
`deepspeed.runtime.zero.offload_states.get_state_devices`, which returns the
devices of offload states, as suggested in this
[comment](https://github.com/microsoft/DeepSpeed/pull/6011#issuecomment-2358068777).
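
A hypothetical usage sketch; the module path and function name come from the PR description, but the exact signature and import locations are assumptions:

```python
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum  # assumed location
from deepspeed.runtime.zero.offload_states import get_state_devices

# After calling model.offload_states(), check where the optimizer states live.
devices = get_state_devices(model, OffloadStateTypeEnum.opt_states)
print(devices)  # e.g. {device(type='cpu')}
```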

We could lift this up to `deepspeed.utils` but would need to resolve a
circular import: User code -> `deepspeed.utils` ->
`deepspeed.utils.offload_states` -> `deepspeed.runtime.zero` ->
`deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils`

This will require a significant refactoring as long as we have
`OffloadStateTypeEnum` in `deepspeed.runtime.zero`.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-10-10 02:59:26 +00:00
Nir Sonnenschein d7ca3d8373
reduce setting global variables to reduce torch compile graph breaks (#6541)
Setting global variables during training creates graph breaks when using
torch.compile (reading global variables does not). This commit attempts to
reduce the setting of global variables in the checkpointing flows.
There are two main uses for setting global variables:
1. Sharing data between functions
2. Establishing that this is the first call to the code

In most cases, the data in the global variables can be computed on demand or
set once to an initial state in a configure function.
For the "check that this is the first run" use case, the code was moved to the
configure function.
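
A tiny illustration of the underlying torch.compile behavior being avoided (writing a global inside a compiled region breaks the graph; merely reading one does not):

```python
import torch

SEEN_FIRST_CALL = False

@torch.compile
def step_that_writes_global(x):
    global SEEN_FIRST_CALL
    SEEN_FIRST_CALL = True      # STORE_GLOBAL inside the compiled region -> graph break
    return x * 2

@torch.compile
def step_that_reads_global(x):
    if SEEN_FIRST_CALL:         # reading is guarded, not a graph break
        x = x + 1
    return x * 2
```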

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-10 00:47:44 +00:00
Masahiro Tanaka 7d751ee890
Clean up prefetched parameters (#6557)
Parameters prefetched by ZeRO3 are sometimes not used. This occurs when
the actual sub-module execution differs from previous tracing. As a
result, the state of the allgather handle for such a parameter remains
`INFLIGHT`, causing functions like `empty_partition_cache` to detect it
and throw an error.
This PR resolves the issue by ensuring that communication finishes and
the parameters are freed.

As this issue was mentioned in #6011, this includes the change of the
branch. We need to merge #6011 first.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-10-09 15:23:33 +00:00
gyou2021 474a3288cd
Enabled Qwen2-MoE Tensor Parallelism (TP) inference (#6551)
Modified `_replace_module` in auto_tp.py:
The modification keeps the 'shared_expert_gate' and 'gate' layers in qwen2-moe
as their original type, torch.nn.Linear, instead of changing them into
LinearLayer. This way, their weights are not split across multiple HPU/GPU
cards, and qwen2-moe can run on multiple HPU/GPU cards. Since the weights of
'gate' are not split across cards, all-gather operations are not needed, which
may improve performance.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-09 15:23:16 +00:00
Omar Elayan 645639bcf8
Rearrange inference OPS and stop using builder.load (#5490)
This PR mainly handles all places where InferenceBuilder is used to access any
op or a specific implementation of an op.
Instead, an op is defined, its proper implementation is picked internally, and
the usage is transparent to the user.
What was done in this PR:
1) Added missing ops (a .py file with a fallback mechanism)
2) Added missing fallback implementations for existing ops
3) Removed all usages of builder.load and replaced them with ops
4) Added a workspace op and InferenceContext, which contains all
workspace-related functions; InferenceContext is the Python fallback of
inferenceContext in CUDA
5) A small change to the softmax_context signature to fit the fallback
signature

---------

Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-10-09 01:22:28 +00:00
Yichen Yan ca8b1fe945
Handle when `backend` is also in compile_kwargs (#6502)
cc @tohtana

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-08 23:38:43 +00:00
Masahiro Tanaka 5cbbff40bd
Fix device selection using CUDA_VISIBLE_DEVICES (#6530)
This PR addresses #5818.
Instead of contiguous numbers based on the device count, this PR uses
device indices in `--include`.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-08 20:41:44 +00:00
Olatunji Ruwase f74ea69abf
Improve DS logging control (#6602)
Disable `steps_per_print` by default.
2024-10-08 18:38:51 +00:00
Yejing-Lai e97b453645
Add llama3.2 vision autotp (#6577)
Llama3.2-11b and llama3.2-90b include both a vision model and a text model.
These two models have different num_kv_heads, so we need to set num_kv_heads
dynamically.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-08 18:16:04 +00:00
Nadav Elyahu 1caf6e8107
add bfloat16 to inference support dtypes (#6528)
to allow running inference tasks using bfloat16
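
A minimal usage sketch (model construction omitted; `dtype=torch.bfloat16` is the newly allowed value):

```python
import torch
import deepspeed

# `model` is any torch.nn.Module built/loaded elsewhere.
engine = deepspeed.init_inference(model, dtype=torch.bfloat16,
                                  replace_with_kernel_inject=False)
outputs = engine(inputs)  # `inputs` prepared elsewhere
```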

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
2024-09-27 06:11:06 +00:00
Masahiro Tanaka 047bcf6af6
Add APIs to offload states of model, optimizer, and engine (#6011)
This PR adds the following APIs to offload model, optimizer, and engine
states.

```python
def offload_states(self,
                   include: Container[OffloadStateTypeEnum] = None,
                   device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                   pin_memory: bool = True,
                   non_blocking: bool = False) -> None:
    """Move the ZeRO optimizer buffers to the specified device.

    Arguments:
        include: Optional. The set of states to offload. If not provided, all states are offloaded.
        device: Optional. The device to move the ZeRO optimizer buffers to.
        pin_memory: Optional. Whether to pin the memory of the offloaded states.
        non_blocking: Optional. Whether to offload the states asynchronously.
...
def offload_states_back(self, non_blocking: bool = False) -> None:
```

Here is the typical usage.
```python
# Offload after forward, backward, and step
model.offload_states()
# Do something requiring a lot of device memory
...
# Load states back to device memory
model.offload_states_back()
```

You can selectively offload states to balance the offloading overhead
and memory saving.
```python
model.offload_states(include=set([OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.opt_states]), device=OffloadDeviceEnum.cpu)
```

Performance (4.3B parameters / 4x A100)
- Environment (4x A100, [benchmark
script](https://gist.github.com/tohtana/05d5faba5068cf839abfc7b1e38b85e4))
- Average Device to Host transfer time: 2.45 GB/s, aggregated: 9.79 GB/s
  - Average Host to Device transfer: 11.05 GB/s, aggregated: 44.19 GB/s
- Mem (allocated by PyTorch)
  - Before offload 18.2GB
  - After offloading 17.7MB
- Time ([benchmark
script](https://github.com/microsoft/DeepSpeedExamples/tree/tohtana/offload_states/training/offload_states),
offloading time/loading time)

`python output_table.py`

|   | pin_memory=0 non_blocking=0 | pin_memory=0 non_blocking=1 | pin_memory=1 non_blocking=0 | pin_memory=1 non_blocking=1 |
|--:|-----------------------------|-----------------------------|-----------------------------|-----------------------------|
| 1 | 4.34 / 3.42 | 4.99 / 2.37 | 6.5 / 2.42  | 6.0 / 2.39  |
| 2 | 9.9 / 3.28  | 5.1 / 2.34  | 6.21 / 2.42 | 6.25 / 2.45 |
| 3 | 9.92 / 3.19 | 6.71 / 2.35 | 6.33 / 2.38 | 5.93 / 2.42 |
| 4 | 9.55 / 2.82 | 7.11 / 2.39 | 6.9 / 2.38  | 6.5 / 2.43  |
| 5 | 4.4 / 3.35  | 6.04 / 2.41 | 6.26 / 2.41 | 6.32 / 2.47 |
| 6 | 4.4 / 3.57  | 6.58 / 2.42 | 6.88 / 2.4  | 6.35 / 2.43 |
| 7 | 9.51 / 3.12 | 6.9 / 2.39  | 6.9 / 2.39  | 6.46 / 2.4  |
| 8 | 4.77 / 3.64 | 6.69 / 2.39 | 7.39 / 2.42 | 6.56 / 2.46 |
| 9 | 9.5 / 3.07  | 7.18 / 2.42 | 6.67 / 2.39 | 7.38 / 2.46 |

TODO:
- Enable offloading to a NVMe storage -> NVMe support is non-trivial. I
suggest adding the support in another PR
- [DONE] Discard buffer (and recreate it) instead of offloading. We
don't need to restore the contiguous buffer for reduce.
- [DONE] Check pin_memory improves performance or not

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-09-27 05:37:32 +00:00
Nir Sonnenschein ba58682a13
fix errors when setting zero3 leaf modules with torch.compile (#6564)
When setting ZeRO-3 leaf modules to a higher-level module and running with
torch.compile, there are a few errors from ZeROOrderedDict.

First, it does not support deepcopy, because it lacks a constructor that takes
no parameters.

Second, it does not check for the existence of the ds_status attribute on a
param before accessing the attribute.

Change contributed by Haifeng Chen

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-09-26 14:55:12 +00:00
Masahiro Tanaka c85c8703bc
Fix gradient accumulation for Z2+offload (#6550)
The ZeRO 1/2 optimizer performs incorrect gradient accumulation in the
ZeRO-2 + offloading path. This issue has two main causes:

1) The micro_step_id in the ZeRO 1/2 optimizer is:

- Initialized to 0 in the constructor.
- Reset to -1 during the backward pass.

For example, given a gradient accumulation step of 4, the micro_step_id
changes as follows:

- For the first global step: 1, 2, 3, 4.
- Subsequently: 0, 1, 2, 3.

2) Gradients are copied to the buffer on the first micro step and
accumulated in the buffer during the following micro steps. However, the
current code incorrectly copies gradients at steps that are not at the
accumulation boundary.

This PR aligns the micro_step_id initialization in both the constructor
and the backward pass, and corrects the condition for copying and
accumulating gradients.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-09-26 13:11:24 +00:00
Olatunji Ruwase a5400974df
DeepNVMe perf tuning (#6560)
Add performance tuning utilities: `ds_nvme_tune` and `ds_io`.  
Update tutorial with tuning section.

---------

Co-authored-by: Ubuntu <jomayeri@microsoft.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
2024-09-26 13:07:19 +00:00
Masahiro Tanaka 7622cd9e68
Use msgpack for p2p comm (#6547)
Use msgpack for P2P communication in pipeline engine.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-09-26 00:34:38 +00:00
Logan Adams 170b46e8b1
Add conditional on torch version for scaled_dot_product_attention (#6517)
Changes from #4724 broke support for torch<2.0 in the flops profiler as
the scaled_dot_product_attention [wasn't
added](https://pytorch.org/docs/2.0/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention)
until a beta version in torch 2.0

Resolved: #5534
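
A sketch of the kind of guard described (feature detection instead of version parsing; the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

_HAS_SDPA = hasattr(F, "scaled_dot_product_attention")

def attention_for_profiling(q, k, v):
    if _HAS_SDPA:
        return F.scaled_dot_product_attention(q, k, v)
    # torch < 2.0 fallback: explicit softmax(QK^T / sqrt(d)) @ V
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v
```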

Todo:
- [ ] Test this
- [ ] Issue resolution with users.
2024-09-11 23:21:43 +00:00
Olatunji Ruwase 659f6be105
Avoid security issues of subprocess shell (#6498)
Avoid security issues of `shell=True` in subprocess
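
A generic before/after sketch of the pattern (the command is a placeholder, not the DeepSpeed call site):

```python
import subprocess

hostfile = "/job/hostfile"

# Risky: the whole string is interpreted by the shell, so metacharacters in
# arguments (spaces, ;, $, backticks) can be executed.
subprocess.check_output(f"wc -l {hostfile}", shell=True)

# Safer: pass an argument list; no shell is involved at all.
subprocess.check_output(["wc", "-l", hostfile])
```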

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-09-11 20:07:06 +00:00
Nadav Elyahu 8fa6b50bfe
Revert "BF16 optimizer: Clear lp grads after updating hp grads in hook" (#6508)
Reverts microsoft/DeepSpeed#5328
After an offline discussion with @YangQun1, we agreed that there is no memory
effect, as the clear_lp_grads flag triggers zero_() ops, which just zero the
buffers and do not free any memory. The outcome is compute overhead.
2024-09-09 15:27:54 +00:00
Geary.Z fc22d9602d
fix environment variable export bug for MultiNodeRunner (#5878)
In some multi-node environments like SLURM, there are environment vars that
contain special characters and can trigger errors when being exported.

For example, there is a var `SLURM_JOB_CPUS_PER_NODE=64(x2)` when requesting
two nodes with 64 CPUs using SLURM.
Using `runner.add_export` to export this var will add a command `export
SLURM_JOB_CPUS_PER_NODE=64(x2)` when launching subprocesses, which causes a
bash error since `(` is a bash keyword, like:
```
[2024-08-07 16:56:24,651] [INFO] [runner.py:568:main] cmd = pdsh -S -f 1024 -w server22,server27 export PYTHONPATH=/public/home/grzhang/code/CLIP-2;  export SLURM_JOB_CPUS_PER_NODE=64(x2); ...
server22: bash: -c: line 0: syntax error near unexpected token "("
```
This PR simply wraps the environment vars in a pair of `"` to make sure they
are treated as strings.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-09-07 22:37:14 +00:00
Nadav Elyahu 3b09d945ea
fix pipeline eval_batch micro_batches argument for schedule (#6484)
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-09-05 22:07:57 +00:00
Olatunji Ruwase 662a421b05
Safe usage of popen (#6490)
Avoid shell=True security issues with Popen
2024-09-04 21:06:04 +00:00