Mirror of https://github.com/microsoft/DeepSpeed.git
Parent: de4596bedc
Commit: 470dd6dceb
@@ -799,10 +799,8 @@ class DeepSpeedEngine(Module):
     def zero_elastic_checkpoint(self):
         return self._config.zero_config.elastic_checkpoint
 
-    def zero_has_nvme_offload(self):
-        if not hasattr(self.optimizer, "swap_optimizer"):
-            return False
-        return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu
+    def zero_nvme_offload_optimizer(self):
+        return getattr(self.optimizer, "swap_optimizer", False)
 
     def zero_max_live_parameters(self):
         return self._config.zero_config.max_live_parameters
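The old accessor guarded a missing attribute with `hasattr` and also folded in `params_in_nvme_and_cpu`; the renamed `zero_nvme_offload_optimizer` narrows that to optimizer swapping alone, via `getattr` with a default. A minimal sketch of that pattern, using hypothetical stub optimizers rather than DeepSpeed classes:

    # Stubs standing in for optimizers with and without NVMe swapping enabled.
    class _PlainOptimizerStub:
        pass                        # no swap_optimizer attribute at all

    class _SwappingOptimizerStub:
        swap_optimizer = True       # set when optimizer state swaps to NVMe

    for opt in (_PlainOptimizerStub(), _SwappingOptimizerStub()):
        # One lookup replaces the hasattr guard: missing attribute -> False.
        print(getattr(opt, "swap_optimizer", False))    # False, then True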
@@ -2865,7 +2863,7 @@ class DeepSpeedEngine(Module):
         if not success:
             self.optimizer._restore_from_bit16_weights()
 
-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors")
@@ -3205,7 +3203,7 @@ class DeepSpeedEngine(Module):
         self._create_zero_checkpoint_files(save_dir, tag)
         self._save_zero_checkpoint(save_dir, tag)
 
-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")
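Both the load and save paths import `disk_usage` alongside `copytree`, which suggests the elided lines that follow check free space before copying the NVMe swap folder into the checkpoint. A hedged sketch of that pattern; the helper name and the exact check are assumptions, not the code in this commit:

    import os
    from shutil import copytree, disk_usage

    def _copy_offloaded_tensors(offload_dir, offload_ckpt_dir):
        # Total bytes currently held in the NVMe swap folder.
        required = sum(
            os.path.getsize(os.path.join(root, f))
            for root, _, files in os.walk(offload_dir) for f in files)
        # Refuse to copy unless the checkpoint destination has room.
        free = disk_usage(os.path.dirname(offload_ckpt_dir)).free
        assert free > required, f"need {required} bytes, only {free} free"
        copytree(offload_dir, offload_ckpt_dir, dirs_exist_ok=True)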
@@ -153,6 +153,11 @@ class OptimizerSwapper(object):
             'timer_names',
         ]
 
+    def purge_state(self):
+        for swap_info in self.swap_params_info.values():
+            swap_info.tensors = [swap_info.tensors[0]]
+            swap_info.has_state_tensors = False
+
     def swappable_tensor(self, param=None, numel=None):
         assert param is not None or numel is not None, "Either param or numel must be provided"
         if param is not None:
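The new `purge_state` keeps only the first entry of each swap-info tensor list and clears the state flag, so optimizer state initialized against the freshly created model is dropped and the checkpointed state gets swapped in instead. A sketch of the effect on an assumed swap-info shape; this stub only mirrors the two fields the method touches, not the real class:

    from dataclasses import dataclass, field

    @dataclass
    class _SwapInfoStub:
        # First tensor is the parameter itself; the rest are optimizer state.
        tensors: list = field(
            default_factory=lambda: ["param", "exp_avg", "exp_avg_sq"])
        has_state_tensors: bool = True

    info = _SwapInfoStub()
    info.tensors = [info.tensors[0]]   # same statements purge_state runs per entry
    info.has_state_tensors = False
    print(info)   # _SwapInfoStub(tensors=['param'], has_state_tensors=False)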
@@ -2652,11 +2652,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
         self._clear_fp32_optimizer_param_groups()
 
-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()
 
         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
@@ -2773,11 +2771,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             else:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
 
-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()
 
         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
@@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest):
     world_size = 1
 
     @pytest.mark.parametrize('param_offload_device, optim_offload_device',
-                             [(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu),
+                             [(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
                               (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
                               (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)])
     def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device):
         zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")
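The widened parametrization now exercises asymmetric placements: parameters on NVMe with optimizer state in CPU memory or not offloaded at all, and the reverse. A hedged sketch of the ZeRO-3 config such device pairs plausibly map to; the key names follow DeepSpeed's documented `zero_optimization` schema, and the `nvme_path` value is a placeholder:

    def make_config(param_device: str, optim_device: str) -> dict:
        return {
            "zero_optimization": {
                "stage": 3,
                "offload_param": {
                    "device": param_device, "nvme_path": "/local_nvme"},
                "offload_optimizer": {
                    "device": optim_device, "nvme_path": "/local_nvme"},
            },
        }

    make_config("nvme", "cpu")   # one of the newly added combinations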