From 6eed634eda502300b702f7a80c23f24aea08ed29 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Thu, 15 Aug 2024 02:38:45 +0300
Subject: [PATCH] Z3: optimizations for grad norm calculation and gradient
 clipping (#5504)

This PR adds the below functionality:
1. complete_grad_norm_calculation_for_cpu_offload: move total_norm to CPU,
   as the expected device in this case is CPU.
2. replace get_global_norm() with torch.linalg.norm for better performance.
3. unscale_and_clip_grads: replace clipping based on an if statement with
   torch.clamp for better performance.

Change (3) is taken from https://github.com/microsoft/DeepSpeed/pull/5547
(which was closed).

---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Liran Bachar
---
 deepspeed/runtime/zero/stage3.py               | 10 +++++-----
 tests/unit/runtime/zero/test_zero_offloadpp.py |  1 +
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 9b7645261..b0a3ab778 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -15,7 +15,7 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -1413,7 +1413,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
             total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
 
-        return total_norm
+        return total_norm.cpu()
 
     @instrument_w_nvtx
     def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
@@ -2028,7 +2028,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             return
 
         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
@@ -2112,8 +2112,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale
 
         self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
 
diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py
index 5bfec399e..8ae99e223 100644
--- a/tests/unit/runtime/zero/test_zero_offloadpp.py
+++ b/tests/unit/runtime/zero/test_zero_offloadpp.py
@@ -43,6 +43,7 @@ class TestZeroPartialOffloadConfigSweep(DistributedTest):
     config_dict = {
         "train_batch_size": 256,
         "steps_per_print": 1,
+        "gradient_clipping": 1.0,
         "optimizer": {
             "type": "Adam",
             "params": {
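
Note: the following is a minimal standalone sketch, not the DeepSpeed stage3 code
itself, illustrating the two ideas the patch applies: aggregating per-group norms
with torch.linalg.norm and replacing the data-dependent `if clip > 1:` branch with
torch.clamp, which avoids forcing the scalar norm back to the host. The function
name and the toy values below are hypothetical.

import torch

def unscale_and_clip_sketch(grad: torch.Tensor, total_norm: torch.Tensor,
                            loss_scale: float, clip_grad: float) -> torch.Tensor:
    # total_norm is assumed to be the norm of the *scaled* gradients
    # (i.e. norm * loss_scale), matching the comment in the patched code path.
    combined_scale = loss_scale
    if clip_grad > 0.:
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        # clamp to a minimum of 1.0 instead of branching on `clip > 1`,
        # so no GPU->CPU sync of the scalar tensor is required.
        clip = torch.clamp(clip, min=1.0)
        combined_scale = clip * loss_scale
    # unscale (and clip, if needed) in a single in-place multiply.
    return grad.mul_(1. / combined_scale)

# Toy usage: per-group norms aggregated as in the patch (3-4-5 triangle -> 5.0).
norm_groups = [torch.tensor(3.0), torch.tensor(4.0)]
scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
grad = torch.ones(8) * 10.0
unscale_and_clip_sketch(grad, scaled_global_grad_norm, loss_scale=1.0, clip_grad=1.0)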