Mirror of https://github.com/microsoft/DeepSpeed.git
bf16_optimizer: fixes to different grad acc dtype (#6485)
- fix the step function to cast gradients to FP32 before the optimizer step when the gradient accumulation data type differs
- remove the redundant function initialize_optimizer_states()
Parent
9b7fc54524
Commit
cfc6ed3722
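The pattern this commit applies is easiest to see in isolation before reading the hunks below: attach the accumulated gradients to the FP32 master parameters, casting them up to FP32 when the accumulation dtype differs, run the inner optimizer step, and then drop the temporary FP32 gradients. The following is a minimal sketch of that idea, not the DeepSpeed implementation; the function step_with_mixed_grad_dtype and the names master_params, accumulated_grads, and grad_acc_dtype are hypothetical stand-ins.

import torch

def step_with_mixed_grad_dtype(optimizer, master_params, accumulated_grads, grad_acc_dtype):
    # Attach gradients to the FP32 master copies, casting only when the
    # accumulation dtype (e.g. bf16) differs from the parameter dtype.
    for p, g in zip(master_params, accumulated_grads):
        p.grad = g.to(p.dtype) if g.dtype != p.dtype else g

    optimizer.step()

    # When the accumulation dtype is not FP32, the casts above produced
    # temporary FP32 copies; drop the references so they can be freed.
    if grad_acc_dtype is not torch.float32:
        for p in master_params:
            p.grad = None

# Usage sketch: FP32 master weights, bf16 accumulated gradients.
master_params = [torch.nn.Parameter(torch.zeros(4))]
accumulated_grads = [torch.ones(4, dtype=torch.bfloat16)]
opt = torch.optim.SGD(master_params, lr=0.1)
step_with_mixed_grad_dtype(opt, master_params, accumulated_grads, torch.bfloat16)

Clearing .grad only in the mixed-dtype case mirrors the diff below: when the gradients are already FP32 they are the accumulation buffers themselves, so there is no temporary copy to release.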
@@ -197,10 +197,6 @@ class BF16_Optimizer(ZeROOptimizer):
             see_memory_usage(f'after initializing group {i}', force=True)

-        see_memory_usage('before initialize_optimizer', force=True)
-        self.initialize_optimizer_states()
-        see_memory_usage('end initialize_optimizer', force=True)
-
         self._grad_acc_hooks = []
         if self.immediate_grad_update:
             self.create_grad_acc_hooks()
@@ -252,25 +248,6 @@ class BF16_Optimizer(ZeROOptimizer):
                                                     self.optimizer.state)
         self._hp_optimizer_states_linked = True

-    def initialize_optimizer_states(self):
-        """Take an optimizer step with zero-valued gradients to allocate internal
-        optimizer state.
-
-        This helps prevent memory fragmentation by allocating optimizer state at the
-        beginning of training instead of after activations have been allocated.
-        """
-        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
-                                                   self.fp32_groups_gradient_flat_partition):
-            # In case of grad acc dtype different than FP32, need to cast to high precision.
-            param_partition.grad = grad_partition.to(
-                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
-
-        if self.grad_acc_dtype is not torch.float32:
-            for param_partition in self.fp32_groups_flat_partition:
-                param_partition.grad = None
-
-        self.clear_hp_grads()
-
     def _split_flat_tensor(self, flat_tensor, num_elem_list):
         assert sum(num_elem_list) <= flat_tensor.numel()
         tensor_list = []
@@ -317,8 +294,18 @@ class BF16_Optimizer(ZeROOptimizer):
                                       mpu=self.mpu,
                                       use_graph=self.graph_harvesting)

+        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
+                                                   self.fp32_groups_gradient_flat_partition):
+            # In case of grad acc dtype different than FP32, need to cast to high precision.
+            param_partition.grad = grad_partition.to(
+                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
+
         self.optimizer.step()

+        if self.grad_acc_dtype is not torch.float32:
+            for param_partition in self.fp32_groups_flat_partition:
+                param_partition.grad = None
+
         # We need to link optimizer state after the first step() call
         self._lazy_init_hp_params_optimizer_state()