bf16_optimizer: fixes to different grad acc dtype (#6485)

- fix the step function to cast gradients to FP32 before the optimizer step when the
gradient accumulation data type differs from FP32 (see the sketch below)
- remove the redundant function initialize_optimizer_states()
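
The core pattern of the fix can be sketched in isolation as follows. This is a minimal illustration, not the DeepSpeed implementation; step_with_cast, fp32_params, accum_grads, and grad_acc_dtype are placeholder names for this sketch only:

import torch

def step_with_cast(optimizer, fp32_params, accum_grads, grad_acc_dtype):
    """Minimal sketch (not DeepSpeed API): gradients may be accumulated in a
    lower-precision dtype such as bf16, so they are cast to the FP32 master
    parameters' dtype right before optimizer.step(), and the temporary FP32
    copies are dropped afterwards."""
    for p, g in zip(fp32_params, accum_grads):
        # Cast only when the accumulation dtype differs from the param dtype.
        p.grad = g.to(p.dtype) if g.dtype != p.dtype else g

    optimizer.step()

    if grad_acc_dtype is not torch.float32:
        # The .grad tensors are temporary FP32 casts; drop them so only the
        # low-precision accumulation buffers keep holding the gradients.
        for p in fp32_params:
            p.grad = None

# Toy usage: one FP32 master parameter whose gradient was accumulated in bf16.
master = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))
opt = torch.optim.AdamW([master], lr=1e-3)
step_with_cast(opt, [master], [torch.ones(4, dtype=torch.bfloat16)], torch.bfloat16)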
Nadav Elyahu authored on 2024-09-04 18:27:26 +03:00, committed by GitHub
Parent: 9b7fc54524
Commit: cfc6ed3722
1 changed file with 10 additions and 23 deletions


@@ -197,10 +197,6 @@ class BF16_Optimizer(ZeROOptimizer):
             see_memory_usage(f'after initializing group {i}', force=True)
 
-        see_memory_usage('before initialize_optimizer', force=True)
-        self.initialize_optimizer_states()
-        see_memory_usage('end initialize_optimizer', force=True)
-
         self._grad_acc_hooks = []
         if self.immediate_grad_update:
             self.create_grad_acc_hooks()
@@ -252,25 +248,6 @@ class BF16_Optimizer(ZeROOptimizer):
                                                     self.optimizer.state)
             self._hp_optimizer_states_linked = True
 
-    def initialize_optimizer_states(self):
-        """Take an optimizer step with zero-valued gradients to allocate internal
-        optimizer state.
-
-        This helps prevent memory fragmentation by allocating optimizer state at the
-        beginning of training instead of after activations have been allocated.
-        """
-        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
-                                                   self.fp32_groups_gradient_flat_partition):
-            # In case of grad acc dtype different than FP32, need to cast to high precision.
-            param_partition.grad = grad_partition.to(
-                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
-
-        if self.grad_acc_dtype is not torch.float32:
-            for param_partition in self.fp32_groups_flat_partition:
-                param_partition.grad = None
-
-        self.clear_hp_grads()
-
     def _split_flat_tensor(self, flat_tensor, num_elem_list):
         assert sum(num_elem_list) <= flat_tensor.numel()
         tensor_list = []
@@ -317,8 +294,18 @@ class BF16_Optimizer(ZeROOptimizer):
                                         mpu=self.mpu,
                                         use_graph=self.graph_harvesting)
 
+        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
+                                                   self.fp32_groups_gradient_flat_partition):
+            # In case of grad acc dtype different than FP32, need to cast to high precision.
+            param_partition.grad = grad_partition.to(
+                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
+
         self.optimizer.step()
 
+        if self.grad_acc_dtype is not torch.float32:
+            for param_partition in self.fp32_groups_flat_partition:
+                param_partition.grad = None
+
         # We need to link optimizer state after the first step() call
         self._lazy_init_hp_params_optimizer_state()
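
For context, the new code path in step() is only exercised when gradients are accumulated in a dtype other than FP32, which in DeepSpeed is typically selected through the data_types.grad_accum_dtype config entry. The fragment below is an illustrative sketch of such a setup, not a config taken from this PR; consult the DeepSpeed config docs for the authoritative schema:

# Hypothetical DeepSpeed config fragment: bf16 training with bf16 gradient
# accumulation, so the optimizer wrapper must cast grads to FP32 before stepping.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "data_types": {"grad_accum_dtype": "bf16"},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-3}},
}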