Commit graph

171 commits

Author SHA1 Message Date
Logan Adams 2e0c39b55c
Add explicit parameters for torch.load (#6751)
Successor PR to #6094:

> FutureWarning: You are using torch.load with weights_only=False (the
current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute
arbitrary code during unpickling (See
https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models
for more details). In a future release, the default value for
weights_only will be flipped to True. This limits the functions that
could be executed during unpickling. Arbitrary objects will no longer be
allowed to be loaded via this mode unless they are explicitly
allowlisted by the user via torch.serialization.add_safe_globals. We
recommend you start setting weights_only=True for any use case where you
don't have full control of the loaded file. Please open an issue on
GitHub for any issues related to this experimental feature.

Todo:
- [ ] Update values in non-test files to True where necessary.
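
A brief hedged sketch of the change in question: pass `weights_only` explicitly instead of relying on the implicit default (the checkpoint path below is a placeholder, not taken from this PR):

```python
import torch

# Hypothetical checkpoint path, used only for illustration.
ckpt_path = "checkpoint/mp_rank_00_model_states.pt"

# Trusted files and test fixtures can keep the old behavior, but now state it
# explicitly instead of relying on the implicit default.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# For files you do not fully control, prefer the safer mode: only plain tensors
# and explicitly allowlisted globals can be unpickled.
safe_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
```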
2024-11-19 11:09:52 -08:00
Logan Adams d2a4718946
Update yapf version (#6721)
This update is needed to eventually support running on ubuntu-24.04 GitHub
runners: the Python version there is updated to 3.12, which results in the
following error: `ModuleNotFoundError: No module named 'lib2to3'`, since
that package is deprecated.
2024-11-06 18:57:12 +00:00
Yejing-Lai c7f58c899f
Add attribute check to support git-base autotp (#6688)
The git-base model is an image-text model. After adding support for the
llama3.2 vision model, we set num_kv_heads dynamically.
git-base only includes vision_config, so we need to add an attribute
check for vision_config/text_config when setting num_kv_heads.
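
A hedged sketch of the kind of attribute check described above (the config attribute names follow Hugging Face conventions; the helper name is made up for illustration):

```python
def resolve_num_kv_heads(model_config):
    """Illustrative only: pick num_kv_heads from whichever sub-config exists."""
    # Multi-modal configs may carry a text_config (llama3.2 vision) or only a
    # vision_config (git-base), so guard each access with hasattr.
    if hasattr(model_config, "text_config"):
        cfg = model_config.text_config
    elif hasattr(model_config, "vision_config"):
        cfg = model_config.vision_config
    else:
        cfg = model_config
    return getattr(cfg, "num_key_value_heads",
                   getattr(cfg, "num_attention_heads", None))
```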

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-31 00:48:52 +00:00
inkcherry 5fb71c0a18
sequence parallel for uneven heads (#6392)
In sequence_parallel (Ulysses), the number of attention heads must be
divisible by the sequence parallel size, which prevents some
models/workloads from using a specific sequence parallel size. This PR
implements uneven all-to-all head splitting.

- Supports both batch-first (b, s, ...) and seq_len-first (s, b, ...) layouts.
- Added unit tests with numerical checks. Locally also tested with **7
heads with sp=4** and **20 heads with sp=8**, and it passed.
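
A rough sketch of the uneven split this commit describes: heads are distributed across sequence-parallel ranks so that the first `heads % sp` ranks hold one extra head (names are illustrative, not the actual Ulysses implementation):

```python
def split_heads_unevenly(num_heads: int, sp_size: int):
    """Return the number of heads owned by each sequence-parallel rank."""
    base, remainder = divmod(num_heads, sp_size)
    # The first `remainder` ranks each take one extra head.
    return [base + 1 if rank < remainder else base for rank in range(sp_size)]

# 7 heads with sp=4 -> [2, 2, 2, 1]; 20 heads with sp=8 -> [3, 3, 3, 3, 2, 2, 2, 2]
print(split_heads_unevenly(7, 4), split_heads_unevenly(20, 8))
```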

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-25 18:26:47 +00:00
Yejing-Lai e06bb518aa
Add attribute check for language_model when replace last linear module (#6650)
Fix the "module has no attribute 'language_model'" issue.

Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2024-10-23 20:22:59 +00:00
gyou2021 474a3288cd
Enabled Qwen2-MoE Tensor Parallelism (TP) inference (#6551)
Modified _replace_module in auto_tp.py:
The modification keeps the 'shared_expert_gate' and 'gate' layers in
qwen2-moe as their original torch.nn.Linear type instead of changing them
into LinearLayer. This way their weights are not split across multiple
HPU/GPU cards, and qwen2-moe can run on multiple HPU/GPU cards.
Since the 'gate' weights are not split across cards, no all-gather
operations are needed, which may improve performance.
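
A simplified sketch of the idea (not the actual `_replace_module` code): layers whose names indicate a MoE routing gate are left as `torch.nn.Linear` rather than being swapped for the TP-sharded linear:

```python
import torch.nn as nn

# Illustrative stand-in for the TP-aware replacement logic in auto_tp.py.
KEEP_UNSHARDED = ("gate", "shared_expert_gate")

def maybe_replace_linear(name: str, module: nn.Module, make_tp_linear):
    """Keep small routing gates replicated; shard everything else."""
    if isinstance(module, nn.Linear) and name.split(".")[-1] in KEEP_UNSHARDED:
        # Replicated on every rank: no weight split, no all-gather needed.
        return module
    return make_tp_linear(module)
```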

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-09 15:23:16 +00:00
Yejing-Lai e97b453645
Add llama3.2 vision autotp (#6577)
Llama3.2-11b and llama3.2-90b include both a vision model and a text
model; the two models have different num_kv_heads, so we need to set
num_kv_heads dynamically.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-10-08 18:16:04 +00:00
Yejing-Lai 0d3bb77b33
Add chatglm2 & chatglm3 autotp (#5540)
This PR aims to enable chatglm2 & chatglm3 autotp. Similar to phi3,
these models use a chunked MLP layer, so we adjust the weight order via
the 'shard_mlp_chunk' function. Please kindly review~ Thanks!
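
A toy illustration of why a chunked (fused gate+up) MLP weight needs reordering before sharding: the two halves must be split independently and then re-concatenated per rank. This is only a sketch of the idea behind 'shard_mlp_chunk', not the DeepSpeed code:

```python
import torch

def shard_mlp_chunk_sketch(fused_weight: torch.Tensor, rank: int, world_size: int):
    """fused_weight stacks [gate_proj; up_proj] along dim 0 (chunked MLP)."""
    gate, up = fused_weight.chunk(2, dim=0)
    # Shard each chunk separately so every rank gets matching gate/up slices,
    # then concatenate them back into the fused layout that rank expects.
    gate_shard = gate.chunk(world_size, dim=0)[rank]
    up_shard = up.chunk(world_size, dim=0)[rank]
    return torch.cat([gate_shard, up_shard], dim=0)
```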

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
2024-07-23 02:11:21 +00:00
Yang, Bo 9fa4c42443
fix: quantization with DeepSpeed HE (#5624)
When the model is quantized, the hidden sizes cannot be determined from
`ds_shape` and `shape`, because they are one-dimensional. This PR fixes
the bug by determining hidden sizes from `in_features` and
`out_features`.

This PR fixes #5398
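
A hedged sketch of the described fix (names simplified; quantized weights may be stored flattened, so weight shapes are unreliable):

```python
import torch.nn as nn

def get_hidden_sizes(child: nn.Linear):
    """Illustrative: derive (input_size, output_size) without reading weight.shape.

    For a quantized module the stored weight can be 1-D, so `ds_shape` or
    `shape` no longer reflects the layer geometry; the Linear metadata does.
    """
    return child.in_features, child.out_features
```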

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-07-23 00:59:51 +00:00
Omar Elayan 830d0c0a10
[INF] Add Qwen2RMSNorm to loaded layers in auto_tp (#5786)
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-07-23 00:54:19 +00:00
Yejing-Lai a07a3c5d22
Fix phi3 mini 128k load error (#5765)
Fix phi3 mini 128k load error.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-07-15 16:57:45 +00:00
Yejing-Lai 8ea995ee1f
enable yuan autotp & add conv tp (#5428)
This PR aims to enable Yuan model autotp and add conv TP.

The Yuan model uses shared QK.
For example:
q_linear_out = [q1, q2, q3, q4, q5, ... , q16]
k_linear_out = [k1, k2, k3, k4, k5, ... , k16]

after share qk:
TP=1:
q' = [q1,q2,q3,q4,  q9,q10,q11,q12,  k1,k2,k3,k4,  k9,k10,k11,k12]
k' = [q5,q6,q7,q8,  q13,q14,q15,q16,  k5,k6,k7,k8,  k13,k14,k15,k16]
v' = [v1,v2,v3,v4,  v5,v6,v7,v8, v9,v10,v11,v12, v13,v14,v15,v16]

TP=2:
rank0:
q'_0 = [q1,q2,q3,q4, k1,k2,k3,k4]
k'_0 = [q5,q6,q7,q8, k5,k6,k7,k8]
v'_0 = [v1,v2,v3,v4, v5,v6,v7,v8] -> v'_0 is wrong! The expected value is:
[v1,v2,v3,v4, v9,v10,v11,v12]
rank1:
q'_1 = [q9,q10,q11,q12, k9,k10,k11,k12]
k'_1 = [q13,q14,q15,q16, k13,k14,k15,k16]
v'_1 = [v9,v10,v11,v12, v13,v14,v15,v16] -> v'_1 is wrong! The expected
value is: [v5,v6,v7,v8, v13,v14,v15,v16]

To avoid modifying the modeling code, we adjust the value and o_proj
weights to fit this QK layout.

We also added conv TP to support models that include heavy conv
computation. It is similar to the linear TP policy (a sketch follows the
list below).

If it is not the last conv layer:

- 1. Split the conv weight across ranks along the output channel
dimension.
- 2. Apply conv2d.

Otherwise:

- 1. Split the conv weight across ranks along the input channel
dimension.
- 2. Apply conv2d.
- 3. Use allreduce to sum the outputs.
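
A minimal sketch of that conv TP policy, assuming a plain `nn.Conv2d` and a simple process-group all-reduce (illustrative only, not the actual DeepSpeed AutoTP code):

```python
import torch
import torch.distributed as dist
import torch.nn as nn

def shard_conv2d(conv: nn.Conv2d, rank: int, world_size: int, last_conv_layer: bool):
    """Split a conv weight across ranks; the last layer's output is summed later."""
    if not last_conv_layer:
        # Split along the output-channel dimension: each rank computes a slice
        # of the output channels and no reduction is needed.
        weight = conv.weight.chunk(world_size, dim=0)[rank]
        bias = conv.bias.chunk(world_size, dim=0)[rank] if conv.bias is not None else None
    else:
        # Split along the input-channel dimension: every rank produces a partial
        # output over all channels, which must be summed with an all-reduce.
        weight = conv.weight.chunk(world_size, dim=1)[rank]
        # Add the bias only once (here: on rank 0) so it is not summed N times.
        bias = conv.bias if conv.bias is not None and rank == 0 else None
    return weight, bias

def last_layer_forward(partial_out: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(partial_out, op=dist.ReduceOp.SUM)  # sum the partial results
    return partial_out
```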

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-06-18 13:30:46 -07:00
Logan Adams 62ca317829
Switch from double quotes to match single quotes (#5530) 2024-05-13 20:20:21 -07:00
Yejing-Lai 3a7f3aa849
enable phi2 autotp (#5436)
This PR aims to enable phi2 model autotp.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-05-13 20:10:53 +00:00
Yejing-Lai 3dd7ccff81
enable phi3_mini autotp (#5501)
This PR aims to enable phi3 mini autotp.

Phi3 mini uses a chunked MLP. We adjust the linear layer weight order to
support this model.

Please kindly review~ Thanks!

---------

Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2024-05-08 22:04:02 +00:00
Yejing-Lai 8d98e17140
Enable mixtral 8x7b autotp (#5257)
This PR aims to enable mixtral 8x7b (MoE model) autotp.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-03-27 15:27:57 +00:00
Max Kovalenko d9e12d3a68
Fix attention mask handling in the Hybrid Engine Bloom flow (#5101)
The Bloom flow in Hybrid Engine applies the same transformation of the
input mask that is already performed earlier by the transformers
BloomModel::forward.

This results in non-convergence of scores, specifically in DeepSpeed
Chat, on different accelerators, including CUDA and HPU.

The fix removes the redundant mask transformation and application,
producing correct convergence.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2024-03-12 23:50:24 +00:00
Yejing-Lai bc0d24651d
fix fused_qkv model accuracy issue (#5217)
A fused_qkv model cannot correctly choose its fused_qkv type; the
module_name_matches need to be updated.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-03-05 22:31:46 +00:00
Lev Kurilenko e212845e39
Add backwards compatibility w/ older versions of diffusers (<0.25.0) (#5083)
This PR adds backwards compatibility for older versions of `diffusers`
(`<0.25.0`) by updating the `vae` container import logic to account for
changes between the various versions.
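
A hedged illustration of the version-gated import pattern such a change typically uses (the exact module paths below are assumptions for illustration, not the PR's literal code):

```python
from packaging import version
import diffusers

# diffusers reorganized several VAE-related modules around 0.25.0, so choose
# the import path based on the installed version (paths are illustrative).
if version.parse(diffusers.__version__) >= version.parse("0.25.0"):
    from diffusers.models.autoencoders.vae import Decoder
else:
    from diffusers.models.vae import Decoder
```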
2024-02-06 01:37:44 +00:00
Michael Wyatt a049370c0c
Update import for changes to latest diffusers (#5065) 2024-02-02 15:38:44 -08:00
Polisetty V R K Jyothendra Varma 567f97b264
load linear layer weight with given dtype (#4044)
bf16 inference fails due to a data type mismatch, because half is the default value.
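
A small sketch of the idea: cast the loaded weight to the dtype requested for inference instead of assuming the fp16 default (the function and its signature are illustrative):

```python
import torch
import torch.nn as nn

def load_linear_weight(layer: nn.Linear, ckpt_weight: torch.Tensor,
                       dtype: torch.dtype) -> nn.Linear:
    """Illustrative only: copy a checkpoint weight into a layer with the given dtype."""
    # Using the caller-supplied dtype (e.g. torch.bfloat16) instead of a
    # hard-coded half default avoids the mismatch described above.
    layer.weight.data = ckpt_weight.to(device=layer.weight.device, dtype=dtype)
    return layer
```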

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-02-02 15:37:54 -08:00
Yejing-Lai 62afafe812
Update falcon fused type order (#5007)
The selection of fused type depends on the order of fused_type_dict.
If "DecoderLayer" is placed in front of "FalconDecoderLayer", Falcon will
still incorrectly choose glmtype, so "DecoderLayer" needs to be placed at
the last position of fused_type_dict.
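
A toy example of why the ordering matters when matching substrings against a dict; the dict below is a simplified stand-in for fused_type_dict, not its real contents:

```python
# Keys are matched as substrings of the module class name, so more specific
# keys must come before the generic fallback (Python dicts preserve insertion order).
fused_type_dict = {
    "FalconDecoderLayer": "bloomtype",
    "GLMBlock": "glmtype",
    "DecoderLayer": "glmtype",  # generic fallback, must stay last
}

def pick_fused_type(module_name: str) -> str:
    for key, fused_type in fused_type_dict.items():
        if key in module_name:
            return fused_type
    raise ValueError(f"unknown fused qkv module: {module_name}")

# "FalconDecoderLayer" hits its specific entry first and gets "bloomtype";
# if "DecoderLayer" were listed first, it would incorrectly get "glmtype".
print(pick_fused_type("FalconDecoderLayer"))
```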

---------

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-01-26 06:42:31 +00:00
Yejing-Lai e62a47e2e8
Fix T5 and mistral model meta data error (#4958)
Fix 'NotImplementedError: Cannot copy out of meta tensor; no data!'
when loading T5 and mistral from the meta device.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-01-19 18:57:27 +00:00
Yejing-Lai 29417ab55f
fix uneven issue & add balance autotp (#4697)
This PR aims to balance the shard size across workers as evenly as
possible.
1. We refactor the tp_shard logic so that AutoTP works when
split_shape % num_kv_heads != 0.
2. When num_kv_heads is defined, the attention module relies on it for
sharding, but the mlp and lm_head modules can use near-even division to
get more balanced shards, which gives better performance.
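
A sketch of near-even division next to head-aligned sharding; this is a simplified take on the idea, not DeepSpeed's get_shard_size:

```python
def shard_size_by_heads(hidden_size: int, num_kv_heads: int,
                        world_size: int, rank: int) -> int:
    """Attention: shard whole kv heads, handing the remainder out head by head."""
    head_dim = hidden_size // num_kv_heads
    base, rem = divmod(num_kv_heads, world_size)
    heads_on_rank = base + (1 if rank < rem else 0)
    return heads_on_rank * head_dim

def shard_size_near_even(size: int, world_size: int, rank: int) -> int:
    """MLP / lm_head: no head constraint, so split as evenly as possible."""
    base, rem = divmod(size, world_size)
    return base + (1 if rank < rem else 0)
```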

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2024-01-12 22:56:07 +00:00
inkcherry ee7db48373
autoTP for Qwen (#4902)
Enabled autoTP for the Qwen model, added some module matching, and
adjusted TP-related variables. Verification was conducted on Qwen-1_8B
and Qwen-72B-chat.

---------

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2024-01-11 05:13:22 +00:00
Yejing-Lai 16c265c0ce
fix falcon-40b accuracy issue (#4895)
This [PR](https://github.com/microsoft/DeepSpeed/pull/4721) added the
"DecoderLayer": glmtype mapping. It causes the Falcon model to choose
the "glmtype" fused_qkv_type. The Falcon model (including
FalconDecoderLayer) needs to choose 'bloomtype' explicitly.

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
2024-01-09 16:45:54 -08:00
Lev Kurilenko 75db3d7da7
Fix SD workflow to work with latest diffusers version (#4918)
This PR fixes the Stable Diffusion workflow to work with the latest
`diffusers` version (`0.25.0`).

Fixes #4911.

Manual test:
https://github.com/microsoft/DeepSpeed/actions/runs/7452977322
2024-01-08 22:22:22 +00:00
Yejing-Lai 1787673edc
fix num_kv_heads sharding in uneven autoTP for Falcon-40b (#4712)
Falcon-40b fails on uneven autotp; 'num_kv_heads' needs to be added to
the kv_head_names list.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
2024-01-05 21:29:00 +00:00
Yejing-Lai 85132adc31
enable starcoder (kv_head=1) autotp (#4896)
Hi, this PR aims to enable starcoder (kv_head=1) autotp. Please kindly
review. Thanks~

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
2024-01-05 20:54:15 +00:00
Wang, Yi c8c57b8c24
add sharded loading for safetensors in AutoTP (#4854)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
2024-01-05 12:27:52 -08:00
baodi c20f6fa4e0
support baichuan model: (#4721)
* fix Baichuan meta data error
* add BaichuanLayer and DecoderLayer to glmtype when preparing the TP
fused qkvw
* add a get_alibi_mask function for Baichuan to enable TP

---------

Co-authored-by: Lai, Yejing <yejing.lai@intel.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
2023-12-18 10:45:01 -08:00
Omar Elayan 4c2cac0340
Inference changes for incorporating meta loading checkpoint (#4692)
1. In both files, the same logic was applied: when a tensor is on the
meta device, there is no need to move it to the target device.
2. Deleted an unused member of the class.
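
A minimal sketch of the meta-tensor guard described in point 1 (names are illustrative):

```python
import torch

def maybe_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Illustrative: meta tensors carry no data, so moving them is skipped."""
    if tensor.is_meta:
        # The real weights will be materialized later (e.g. from a checkpoint),
        # so there is nothing to copy to the device yet.
        return tensor
    return tensor.to(device)
```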

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2023-12-18 10:29:04 -08:00
baodi faa00b1373
fix falcon model load from_config meta_data error (#4783) 2023-12-15 15:16:33 -08:00
RyanInnerpeace 7b818ee961
improve the way to determine whether a variable is None (#4782)
refactor: improve the way to decide whether a variable is None
fix: type mismatch for judging if current accelerator is in
SUPPORTED_ACCELERATOR_LIST

---------

Co-authored-by: ryan <ruanzhixiang1@huawei.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2023-12-08 20:28:48 +00:00
Wang, Yi 29f840fd1a
fix autoTP issue for mpt (trust_remote_code=True) (#4787)
Fixes https://github.com/microsoft/DeepSpeed/issues/4774

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2023-12-08 18:56:27 +00:00
Dino Chen 6ea44d02c6
fix num_kv_heads sharding in autoTP for the new in-repo Falcon-40B (#4654)
to be compatible with the latest Falcon-40B's `num_kv_heads` in
4a70170c21

![image](https://github.com/microsoft/DeepSpeed/assets/5948851/d20aa6f2-b9af-4104-b9d3-8ba1ab588a6e)

error message like:

![image](https://github.com/microsoft/DeepSpeed/assets/5948851/06ef6dd2-25d5-4b51-8789-36e1b3f94a32)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2023-11-10 18:55:18 +00:00
Ma, Guokai f15cccfa0c
[AutoTP] Make AutoTP work when num_heads not divisible by number of workers (#4011)
* allow number of heads not divisible by number of ranks

* get num_heads from model config, more robust

* simplify logic where num_head itself is sharded

* name tweaks

* make code more robust where num_attention_heads may not be defined in model_config

* support num_key_value_heads < num_attention_heads which is used by llama2

* add test for 5 ranks

* change odd rank # to 3 to avoid test skip

* add get_shard_size function

* modify sharding mechanism according to latest auto TP

* fix accuracy issue

* fix format

* skip tests with fusedqkv

* remove skip of fusedqkv tests

* skip test fusedqkv with odd number of ranks

* support model with n_heads in model_config

* fix TestInjectionPolicy::test[fp32-t5]

* fix uneven_heads on some fusedqkv types (#12)

* odd support fusedqkv

* fix format and clear text

* better fix when activation size cannot be divided by number of heads

* move tp_shard.py under module_inject

* Add get_num_kv_heads in tp_shard.py

* Refine according to comments

* remove old comment

* fix bug in getting num_kv_heads

* support uneven sharding of lm_head tensor parallel

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: mzl <mingzhi.liu@intel.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
2023-10-25 21:35:17 +00:00
Ilya Vologin beed962c25
[Bug fix] Add rope_theta for llama config (#4480)
* Add rope_theta for llama config

* Add rope_theta to bias_add_transform_0213

* Fix CI problems

* Add rope_theta to linear layer
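
For context, rope_theta is the base of the rotary-embedding frequency schedule. A hedged sketch of how it enters the inverse-frequency computation (the standard RoPE formula, not the kernel code touched by this PR):

```python
import torch

def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    """Standard RoPE inverse frequencies: 1 / theta^(2i / head_dim)."""
    # Llama variants override rope_theta in their config (e.g. larger bases for
    # long-context models), which is why the value must be plumbed through.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)
```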

---------

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2023-10-19 16:48:29 +00:00
Yejing-Lai 6763e2de61
add lm_head and embed_out tensor parallel (#3962)
* add lm_head and embed_out tensor parallel

* fix load lm_head.weight name issue

* replace all_reduce with inference_all_reduce

* refactor lm_head tensor parallel

---------

Co-authored-by: Chen, Zhenhuan <zhenhuan.chen@intel.com>
2023-10-09 10:24:16 +00:00
Wang, Yi d72edb3b0d
fix lm head overridden issue, move it from checkpoint in-loop loading to out of loop (#4206)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2023-10-05 22:31:57 +00:00
Yejing-Lai 7220e7f8f7
fix cpu loading model partition OOM (#4353)
* fix cpu loading model partition OOM

* clean up

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2023-09-28 23:45:51 +00:00
Elsa Granger 4fc2c8e7d5
Fix llama meta tensor loading in AutoTP and kernel injected inference (#3608)
* Adapt to Llama when using meta tensor to load

* Fix gated mlp parameter mp

* Re-enable meta tensor for kernel injection
Fix layer params loading in meta tensor

* Revert mlp_inter_mp for gated mlp as it is fixed

* Monkey patch for fixing llama output

* Fix formatting

* Add comment

---------

Co-authored-by: Lev Kurilenko <lekurile@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2023-09-20 20:30:38 +00:00
Reza Yazdani 468882fb68
Add the policy to run llama model from the official repo (#4313)
* Add the llama2 support from the official llama repo

* add back commented function

* add new policy & implementation for llama2

* add some changes to inject/run the 70b llama model

* remove debugging code

* remove more debugging code

* formatting

* use num_kv only when it has positive value

* use the num_kv param only if  it is positive

* fix syntax and format errors.

* fix an issue with the float32 transform kernel

---------

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
2023-09-19 16:57:55 +00:00
WRH 367d6f9cec
Support InternLM (#4137)
* correct inference with some debug codes.

* remove prints

* update transformer import set_qkv and format

* support some lora abstract method

* fix attn_ob

* some debug

* leave orig layer set by user

* remove debugs

* move attn ob to mlp module

* move import transformer

* init orig class only once

* remove copyright

---------

Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2023-09-18 18:11:12 +00:00
Ammar Ahmad Awan b9d719a6d3
Pass base_dir so model files can be loaded for auto-tp/meta-tensor. (#4348) 2023-09-15 21:26:56 +00:00
stephen youn ffd82bb048
added a bert-model check for triton (#4266)
Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2023-09-11 20:10:01 +00:00
Satpal Singh Rathore 430510bfce
Checks for user injection policy (#3052)
* check injection policy

* transformers v4

* move check_inference_tuple

* user injection policy check in infer engine

* fix pre-commit format

* fix formatting

* fix clang format

---------

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
2023-09-03 18:58:57 +00:00
Dino Chen 6cbf666131
fix MegatronLayerPolicy to be compatible with the newest ParallelTransformerLayer (#4236)
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
2023-08-30 23:28:43 +00:00
Molly Smith 042115c80b
Fix fused qkv sizing for bloom (#4161)
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2023-08-29 14:30:30 +00:00
Dino Chen 0712e29920
add meta onDevice support for LLAMA2 (#4147)
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
2023-08-24 23:22:14 +00:00