Z3: optimizations for grad norm calculation and gradient clipping (#5504)

This PR adds the following functionality:
1. complete_grad_norm_calculation_for_cpu_offload: move total_norm to
CPU, since the expected device in this case is CPU.
2. Replace get_global_norm() with torch.linalg.norm for better
performance (see the sketch after this description).
3. unscale_and_clip_grads: replace the if-statement-based clipping with
torch.clamp for better performance (see the equivalence check after the
first file's diff).

Change (3) is taken from
https://github.com/microsoft/DeepSpeed/pull/5547 (which was closed).
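
A minimal standalone sketch of change (2), assuming the per-group norms are
already available as scalar tensors on a single device (the values below are
made up; this is not the DeepSpeed code itself):

import torch

# Per-group gradient norms (illustrative values).
norm_groups = [torch.tensor(2.0), torch.tensor(3.0), torch.tensor(6.0)]

# Single fused norm-of-norms: sqrt(2^2 + 3^2 + 6^2) == 7.0,
# replacing the get_global_norm() helper used previously.
scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
print(scaled_global_grad_norm)  # tensor(7.)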

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Liran Bachar <lbachar@habana.ai>
Author: Nadav Elyahu, 2024-08-15 02:38:45 +03:00 (committed via GitHub)
Parent: 19b01e1d60
Commit: 6eed634eda
2 changed files with 6 additions and 5 deletions


@@ -15,7 +15,7 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -1413,7 +1413,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
         total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

-        return total_norm
+        return total_norm.cpu()

     @instrument_w_nvtx
     def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
@@ -2028,7 +2028,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             return

         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))

         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
@@ -2112,7 +2112,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale

         self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
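
As a sanity check for change (3), the standalone sketch below (the helper
names are made up, not DeepSpeed APIs) shows that the clamp-based form
produces the same combined scale as the removed branch. A plausible reason
it is faster is that evaluating "if clip > 1:" on a tensor forces a
host-side read of the value, while torch.clamp does not; the PR itself only
states "better performance".

import torch

def combined_scale_branch(total_norm, loss_scale, clip_grad):
    # Old behaviour: data-dependent Python branch on a tensor value.
    combined_scale = loss_scale
    if clip_grad > 0.:
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            combined_scale = clip * loss_scale
    return torch.as_tensor(combined_scale)

def combined_scale_clamp(total_norm, loss_scale, clip_grad):
    # New behaviour: branch-free clamp, identical result.
    combined_scale = loss_scale
    if clip_grad > 0.:
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        clip = torch.clamp(clip, min=1.0)
        combined_scale = clip * loss_scale
    return torch.as_tensor(combined_scale)

for norm in (0.5, 1.0, 7.0):
    total_norm = torch.tensor(norm)
    assert torch.isclose(combined_scale_branch(total_norm, 1.0, 1.0),
                         combined_scale_clamp(total_norm, 1.0, 1.0))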


@@ -43,6 +43,7 @@ class TestZeroPartialOffloadConfigSweep(DistributedTest):
         config_dict = {
             "train_batch_size": 256,
             "steps_per_print": 1,
+            "gradient_clipping": 1.0,
             "optimizer": {
                 "type": "Adam",
                 "params": {