Precisely track nvme optimizer offload (#6963)

Fix #4998
Olatunji Ruwase 2025-01-23 11:42:06 -05:00 committed by GitHub
Parent de4596bedc
Commit 470dd6dceb
No known key found for this signature
GPG key ID: B5690EEEBB952194
4 changed files with 16 additions and 15 deletions

View file

@@ -799,10 +799,8 @@ class DeepSpeedEngine(Module):
     def zero_elastic_checkpoint(self):
         return self._config.zero_config.elastic_checkpoint
 
-    def zero_has_nvme_offload(self):
-        if not hasattr(self.optimizer, "swap_optimizer"):
-            return False
-        return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu
+    def zero_nvme_offload_optimizer(self):
+        return getattr(self.optimizer, "swap_optimizer", False)
 
     def zero_max_live_parameters(self):
         return self._config.zero_config.max_live_parameters
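The renamed predicate narrows what counts as NVMe optimizer offload: the old zero_has_nvme_offload() also returned True when only parameters (params_in_nvme_and_cpu) lived on NVMe, while zero_nvme_offload_optimizer() reports optimizer-state swapping alone, and getattr with a False default keeps it safe for optimizers that lack the attribute. A minimal sketch of the behavioral difference; the stand-in classes are hypothetical, with attribute names taken from the Stage-3 optimizer hunks below:

    class _Stage3Like:
        def __init__(self, swap_optimizer, params_in_nvme_and_cpu):
            self.swap_optimizer = swap_optimizer                  # optimizer state on NVMe
            self.params_in_nvme_and_cpu = params_in_nvme_and_cpu  # parameters on NVMe/CPU

    class _PlainOptimizer:
        pass  # e.g. a non-ZeRO-3 optimizer: neither attribute exists

    def old_predicate(optimizer):  # previous zero_has_nvme_offload logic
        if not hasattr(optimizer, "swap_optimizer"):
            return False
        return optimizer.swap_optimizer or optimizer.params_in_nvme_and_cpu

    def new_predicate(optimizer):  # new zero_nvme_offload_optimizer logic
        return getattr(optimizer, "swap_optimizer", False)

    param_only = _Stage3Like(swap_optimizer=False, params_in_nvme_and_cpu=True)
    assert old_predicate(param_only) is True    # imprecise: reported optimizer offload
    assert new_predicate(param_only) is False   # precise: no optimizer state on NVMe
    assert new_predicate(_PlainOptimizer()) is False  # safe when the attribute is absent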
@@ -2865,7 +2863,7 @@ class DeepSpeedEngine(Module):
         if not success:
             self.optimizer._restore_from_bit16_weights()
 
-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors")
@@ -3205,7 +3203,7 @@ class DeepSpeedEngine(Module):
         self._create_zero_checkpoint_files(save_dir, tag)
         self._save_zero_checkpoint(save_dir, tag)
 
-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")
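Both the load-time and save-time hunks guard the same folder copy and are truncated after the offload_ckpt_dir assignment, so the following is only a plausible sketch of the step that presumably follows, assuming a free-space check via disk_usage before copytree; the required-space computation and assertion text are assumptions, not the exact source:

    import os
    from shutil import copytree, disk_usage

    def copy_swap_folder(src_dir: str, dst_dir: str) -> None:
        # Total bytes of swapped tensors currently in the source folder.
        required = sum(
            os.path.getsize(os.path.join(root, name))
            for root, _, names in os.walk(src_dir) for name in names)
        free = disk_usage(os.path.dirname(dst_dir) or ".").free
        assert free >= required, f"need {required} bytes free, have {free}"
        copytree(src_dir, dst_dir)

    # save: copy_swap_folder(offload_dir, offload_ckpt_dir)
    # load: copy_swap_folder(offload_ckpt_dir, offload_dir), assuming the
    #       swap folder is cleared first so copytree finds no existing target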

View file

@@ -153,6 +153,11 @@ class OptimizerSwapper(object):
             'timer_names',
         ]
 
+    def purge_state(self):
+        for swap_info in self.swap_params_info.values():
+            swap_info.tensors = [swap_info.tensors[0]]
+            swap_info.has_state_tensors = False
+
     def swappable_tensor(self, param=None, numel=None):
         assert param is not None or numel is not None, "Either param or numel must be provided"
         if param is not None:
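The new purge_state() centralizes the reset that the Stage-3 load paths previously inlined: each swap_info keeps only tensors[0] (the parameter tensor) and drops its optimizer-state tensors. A runnable usage sketch; SwapInfoStub and SwapperStub are hypothetical stand-ins for the real swap_params_info entries and OptimizerSwapper:

    class SwapInfoStub:
        def __init__(self, tensors):
            self.tensors = tensors
            self.has_state_tensors = len(tensors) > 1

    class SwapperStub:
        def __init__(self, swap_params_info):
            self.swap_params_info = swap_params_info

        def purge_state(self):  # same body as OptimizerSwapper.purge_state above
            for swap_info in self.swap_params_info.values():
                swap_info.tensors = [swap_info.tensors[0]]
                swap_info.has_state_tensors = False

    swapper = SwapperStub({0: SwapInfoStub(["param", "exp_avg", "exp_avg_sq"])})
    swapper.purge_state()
    assert swapper.swap_params_info[0].tensors == ["param"]
    assert swapper.swap_params_info[0].has_state_tensors is False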

View file

@@ -2652,11 +2652,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
         self._clear_fp32_optimizer_param_groups()
 
-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()
 
         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
@@ -2773,11 +2771,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             else:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
 
-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()
 
         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
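With purge_state() in place, both restore paths reduce to the same shape, gated on swap_optimizer alone, so parameter-only NVMe offload no longer triggers an optimizer-state purge. A condensed sketch of the resulting pattern; _load_optimizer_state is a hypothetical wrapper name, and the OPTIMIZER_STATE_DICT value is the one conventionally used in DeepSpeed checkpoints:

    OPTIMIZER_STATE_DICT = 'optimizer_state_dict'  # assumed constant value

    def _load_optimizer_state(self, state_dict):
        self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        if self.swap_optimizer:
            # The swapped state was initialized from the freshly created
            # model; drop it so the checkpointed state is swapped in instead.
            self.optimizer_swapper.purge_state()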

View file

@@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest):
     world_size = 1
 
     @pytest.mark.parametrize('param_offload_device, optim_offload_device',
-                             [(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu),
+                             [(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
                               (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
                               (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)])
     def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device):
         zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")
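The widened parametrization adds cases where NVMe holds only parameters or only optimizer state, matching the more precise tracking above. A hedged sketch of the kind of ZeRO-3 config such a matrix would drive; the key names follow DeepSpeed's public config schema, but the surrounding values a real test sets (batch size, precision, optimizer, buffer sizes) are omitted or illustrative:

    # Illustrative config builder; "device" values correspond to
    # OffloadDeviceEnum members ("none", "cpu", "nvme"). nvme_path is only
    # meaningful for the nvme device but is harmless otherwise.
    def make_ds_config(param_device: str, optim_device: str, nvme_dir: str) -> dict:
        return {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": 3,
                "offload_param": {"device": param_device, "nvme_path": nvme_dir},
                "offload_optimizer": {"device": optim_device, "nvme_path": nvme_dir},
            },
        }

    # e.g. the new (nvme, none) case: parameters swap to NVMe, optimizer stays put
    config = make_ds_config("nvme", "none", "/tmp/zero")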