Mirror of https://github.com/microsoft/DeepSpeed.git
Support non-tensor state in checkpoint (#548)
Parent: 0178e6cc22
Commit: 6021b70288
@@ -947,9 +947,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                                         state_key,
                                         all_partition_states,
                                         max_elems_per_comm):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        alignment = dist.get_world_size(group=self.dp_process_group)
+        if not torch.is_tensor(all_partition_states[0]):
+            return all_partition_states[0]
 
+        alignment = dist.get_world_size(group=self.dp_process_group)
         flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_partition_states,
             dp=dist.get_world_size(group=self.dp_process_group),
@@ -964,6 +965,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             dp_process_group=self.dp_process_group
         )
 
+        partition_id = dist.get_rank(group=self.dp_process_group)
         return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
 
     # Compute the optimizer state partitions for the group by
@@ -1013,8 +1015,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         for group_idx, group in enumerate(self.optimizer.param_groups):
             for param_idx, param in enumerate(group['params']):
                 for key, saved in base_optimizer_group_states[group_idx].items():
-                    current = self.optimizer.state[param][key]
-                    current.data.copy_(saved[param_idx].data)
+                    if torch.is_tensor(self.optimizer.state[param][key]):
+                        current = self.optimizer.state[param][key]
+                        current.data.copy_(saved[param_idx].data)
+                    else:
+                        self.optimizer.state[param][key] = saved
 
     # Restore base optimizer fp32 weights from ZeRO fp16 weights
     def _restore_from_fp16_weights(self):
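The stage-1 changes above boil down to one rule: only tensor-valued optimizer state is flattened and partitioned across the data-parallel group, while any non-tensor entry (for example a plain Python step counter) is passed through unchanged on save and restored by assignment instead of an in-place copy. Below is a minimal standalone sketch of that rule; the helper names are hypothetical and are not the actual DeepSpeed internals.

import torch

def partition_state_value(all_partition_states, partition_fn):
    # Non-tensor state (e.g. an int step counter) is identical on every rank,
    # so return it as-is instead of trying to flatten and partition it.
    if not torch.is_tensor(all_partition_states[0]):
        return all_partition_states[0]
    return partition_fn(all_partition_states)

def restore_state_entry(optimizer_state, key, saved_value):
    # Tensors are restored with an in-place copy so existing storage is reused;
    # non-tensor values are restored by simple assignment.
    if torch.is_tensor(optimizer_state[key]):
        optimizer_state[key].data.copy_(saved_value.data)
    else:
        optimizer_state[key] = saved_value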
@@ -101,6 +101,37 @@ class SimpleOptimizer(torch.optim.Optimizer):
         return loss
 
 
+class HybridStateOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, lr=0.11072018):
+        defaults = dict(lr=lr)
+        super(HybridStateOptimizer, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(HybridStateOptimizer, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['integer_step'] = 0
+                    state['tensor_step'] = torch.zeros(1)
+
+                d_p = p.grad.data
+                p.data.add_(-group['lr'], d_p)
+                state['integer_step'] += 1
+                state['tensor_step'] += 1
+
+        return loss
+
+
 class PLD_SimpleModel(SimpleModel):
     def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
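HybridStateOptimizer above is a test-only SGD variant whose per-parameter state deliberately mixes a plain Python int ('integer_step') with a tensor ('tensor_step'), which is exactly the kind of state the new checkpoint path has to round-trip. A short usage sketch, assuming the class defined above is in scope:

import torch
from torch import nn

model = nn.Linear(10, 10)
opt = HybridStateOptimizer(model.parameters())

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()

# After one step the per-parameter state holds both tensor and non-tensor entries.
state = opt.state[next(model.parameters())]
print(type(state['integer_step']))  # <class 'int'>
print(type(state['tensor_step']))   # <class 'torch.Tensor'>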
@@ -36,6 +36,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     compare_deepspeed_states(saved_model, loaded_model)
 
     for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
+        assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
         assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
 
     if not compare_optimizer:
@@ -43,20 +44,24 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
 
     if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
         for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
 
     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
         for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
             for p0, p1 in zip(partition0, partition1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
 
     elif isinstance(saved_model.optimizer, FP16_Optimizer):
         for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
 
     elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
         pass
@@ -72,6 +77,7 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
                               loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
             if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
+                assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
                 assert torch.equal(s0, s1)
             else:
                 assert s0 == s1
@@ -100,18 +106,34 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
         assert state0 == state1
 
 
+def create_deepspeed_model(args, model, base_optimizer):
+    if base_optimizer is None:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 model_parameters=model.parameters())
+    else:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 optimizer=base_optimizer)
+
+    return ds_model
+
+
 def checkpoint_correctness_verification(args,
-                                        model,
+                                        models,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
                                         fp16=True,
-                                        train_batch=False):
+                                        train_batch=False,
+                                        base_optimizers=[None,
+                                                         None]):
     dtype = torch.half if fp16 else torch.float32
-    ds_model, _, _, _ = deepspeed.initialize(args=args,
-                                             model=model,
-                                             model_parameters=model.parameters())
+    ds_model = create_deepspeed_model(args=args,
+                                      model=models[0],
+                                      base_optimizer=base_optimizers[0])
 
     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
@@ -125,7 +147,6 @@ def checkpoint_correctness_verification(args,
     else:
         for n, batch in enumerate(data_loader):
             loss = ds_model(batch[0], batch[1])
-            print(loss)
             ds_model.backward(loss)
             ds_model.step()
 
@@ -136,9 +157,9 @@ def checkpoint_correctness_verification(args,
 
     trained_model.save_checkpoint(save_folder, save_tag)
 
-    loaded_model, _, _, _ = deepspeed.initialize(args=args,
-                                                 model=model,
-                                                 model_parameters=model.parameters())
+    loaded_model = create_deepspeed_model(args=args,
+                                          model=models[1],
+                                          base_optimizer=base_optimizers[1])
 
     loaded_model.load_checkpoint(save_folder,
                                  save_tag,
@@ -191,25 +212,26 @@ def test_checkpoint_unfused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_unfused_optimizer(args,
-                                           model,
+                                           models,
                                            hidden_dim,
                                            load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=True)
 
     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
 
@@ -236,22 +258,26 @@ def test_checkpoint_fused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
-    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_fused_optimizer(args,
+                                         models,
+                                         hidden_dim,
+                                         load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=True)
 
     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False)
 
@@ -293,18 +319,18 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
-    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_zero_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_zero_optimizer(args=args,
-                                    model=model,
+                                    models=models,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
 
@@ -346,21 +372,21 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_no_optimizer(args,
-                                           model,
+                                           models,
                                            hidden_dim,
                                            load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_zero_no_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
 
@@ -412,24 +438,24 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
    @distributed_test(world_size=[2])
     def _test_checkpoint_lr_scheduler(args,
-                                      model,
+                                      models,
                                       hidden_dim,
                                       load_optimizer_states,
                                       load_lr_scheduler_states):
         checkpoint_correctness_verification(
             args,
-            model,
-            hidden_dim,
-            tmpdir,
+            models=models,
+            hidden_dim=hidden_dim,
+            tmpdir=tmpdir,
             load_optimizer_states=load_optimizer_states,
             load_lr_scheduler_states=load_lr_scheduler_states)
 
     _test_checkpoint_lr_scheduler(args=args,
-                                  model=model,
+                                  models=models,
                                   hidden_dim=hidden_dim,
                                   load_optimizer_states=False,
                                   load_lr_scheduler_states=True)
@@ -478,24 +504,24 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_no_lr_scheduler(args,
-                                         model,
+                                         models,
                                          hidden_dim,
                                          load_optimizer_states,
                                          load_lr_scheduler_states):
         checkpoint_correctness_verification(
             args,
-            model,
-            hidden_dim,
-            tmpdir,
+            models=models,
+            hidden_dim=hidden_dim,
+            tmpdir=tmpdir,
             load_optimizer_states=load_optimizer_states,
             load_lr_scheduler_states=load_lr_scheduler_states)
 
     _test_checkpoint_no_lr_scheduler(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False,
                                      load_lr_scheduler_states=False)
@@ -523,13 +549,17 @@ def test_checkpoint_fp32_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
-    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
-        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)
+    def _test_checkpoint_fp32_optimizer(args, models, hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            fp16=False)
 
-    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
+    _test_checkpoint_fp32_optimizer(args=args, models=models, hidden_dim=hidden_dim)
 
 
 @pytest.mark.parametrize("zero_stage", [0, 1])
@@ -571,10 +601,10 @@ def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
     @distributed_test(world_size=4)
     def _test(save_folder, num_stages):
         args = args_from_dict(tmpdir, config_dict)
-        model = LinearStackPipe(num_stages=num_stages)
+        models = [LinearStackPipe(num_stages=num_stages) for _ in range(2)]
         checkpoint_correctness_verification(args=args,
-                                            model=model,
-                                            hidden_dim=model.hidden_dim,
+                                            models=models,
+                                            hidden_dim=models[0].hidden_dim,
                                             tmpdir=save_folder,
                                             fp16=config_dict['fp16']['enabled'],
                                             load_optimizer_states=True,
@@ -635,3 +665,42 @@ def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
             assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"
 
     _test(base_topo, test_topo, save_folder=tmpdir)
+
+
+@pytest.mark.parametrize('zero_stage', [1, 2])
+def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 2,
+        "gradient_accumulation_steps": 2,
+        "steps_per_print": 1,
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+        "zero_allow_untested_optimizer": True,
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 8
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+    optimizers = [HybridStateOptimizer(model.parameters()) for model in models]
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_zero_hybrid_optimizer_state(args,
+                                                     models,
+                                                     optimizers,
+                                                     hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            base_optimizers=optimizers,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True)
+
+    _test_checkpoint_zero_hybrid_optimizer_state(args=args,
+                                                 models=models,
+                                                 optimizers=optimizers,
+                                                 hidden_dim=hidden_dim)