Fix training of pipeline based peft's lora model (#5477)

Hi, guys

I hit a failure when training a LoRA model built with Hugging Face's
peft in pipeline-parallel style.

Here are the steps I used to create my model:
1) Load the pre-trained chatglm-6b model from Hugging Face, as Model_A
2) Use peft's `get_peft_model(...)` with my `LoraConfig(...)` on Model_A
to create the LoRA model, as Model_B
3) Create my own pipeline-based model, Model_C, from Model_B

I ran Model_C on two 3090 Ti GPUs, and the failure looks like this:
```text
Traceback (most recent call last):
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 372, in <module>
    main()
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 351, in main
    loss = engine.train_batch(data_iter=train_dataloader)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 375, in train_batch
    self._exec_schedule(sched)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 1375, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 276, in _exec_reduce_tied_grads
    dist.all_reduce(grad, group=group)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 496, in all_reduce
    return cdb.all_reduce(tensor, op, group, async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 159, in all_reduce
    return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1520, in all_reduce
    _check_single_tensor(tensor, "tensor")
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 463, in _check_single_tensor
    raise RuntimeError(
RuntimeError: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.
```

After some debugging, I found the root cause: my LoRA configuration
(below) only adds LoRA layers to the qkv-related layers, not to the
embedding layer. So all of the embedding layer's parameters are frozen.
```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,  # copied from finetuning_lora.py
    lora_alpha=32,
    target_modules=["query_key_value"],  # only the qkv projection gets LoRA layers
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    inference_mode=False,
)
```
In my pipeline-based model, I declared the embedding layer as a tied
layer. As a result, the embedding layer has no gradients at all, but as
a tied layer it still needs to be synced between the two GPUs. Its
gradient is None, yet it is still passed to the `all_reduce` operation.
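
A minimal sketch of why the gradient ends up as None, assuming a toy
model in plain PyTorch. The layer sizes and variable names are
illustrative only and are not taken from my actual pipeline model:

```python
import torch
import torch.nn as nn

# Toy stand-ins (assumptions): a frozen embedding, playing the role of the
# tied embedding layer, followed by a trainable layer, playing the role of
# the LoRA-adapted layers.
embedding = nn.Embedding(10, 4)
embedding.weight.requires_grad_(False)  # frozen parameter
trainable = nn.Linear(4, 2)

tokens = torch.tensor([[1, 2, 3]])
loss = trainable(embedding(tokens)).sum()
loss.backward()

# Autograd never populates .grad for a parameter that does not require grad.
frozen_grad = embedding.weight.grad
trainable_grad = trainable.weight.grad
print(frozen_grad is None, trainable_grad is not None)
```

Any code that assumes every tied weight has a gradient tensor will
therefore hand None to the collective call for the frozen embedding.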

For now, my fix is simple: check whether `grad` is None before calling
`all_reduce`.
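
The idea of the guard can be sketched in isolation as follows. This is
not the DeepSpeed code itself: `reduce_tied_grads`, `FakeWeight` and
`fake_all_reduce` are hypothetical stand-ins so the logic runs without
GPUs or a process group, and the bf16 `_hp_grad` path is left out:

```python
def reduce_tied_grads(weight_group_list, all_reduce):
    """Reduce each tied weight's gradient, skipping weights without one.

    Both arguments are hypothetical stand-ins: weight_group_list is an
    iterable of (weight, group) pairs, and all_reduce is any callable
    that accepts (grad, group=...).
    Returns the number of gradients that were actually reduced.
    """
    reduced = 0
    for weight, group in weight_group_list:
        grad = weight.grad
        if grad is not None:  # frozen weights have no gradient to sync
            all_reduce(grad, group=group)
            reduced += 1
    return reduced


class FakeWeight:
    """Stand-in for a parameter; only the .grad attribute matters here."""

    def __init__(self, grad):
        self.grad = grad


calls = []


def fake_all_reduce(grad, group=None):
    # Record the call instead of communicating across ranks.
    calls.append((grad, group))


pairs = [(FakeWeight(None), "g0"), (FakeWeight([0.5]), "g1")]
n = reduce_tied_grads(pairs, fake_all_reduce)
```

With one frozen and one trainable weight, only the trainable weight's
gradient reaches the reduce call.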

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
xuanhua 2024-10-30 00:04:35 +08:00 committed by GitHub
Parent 07cac9e021
Commit e4a247ed13
1 changed file: 2 additions and 1 deletion

@@ -287,7 +287,8 @@ class PipelineEngine(DeepSpeedEngine):
         weight_group_list = self.module.get_tied_weights_and_groups()
         for weight, group in weight_group_list:
             grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad
-            dist.all_reduce(grad, group=group)
+            if grad is not None:
+                dist.all_reduce(grad, group=group)

     def _exec_reduce_grads(self):
         self._force_grad_boundary = True