Support non-tensor state in checkpoint (#548)

This commit is contained in:
Olatunji Ruwase 2020-11-21 15:41:22 -08:00 committed by GitHub
Parent 0178e6cc22
Commit 6021b70288
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 163 additions and 58 deletions

View file

@@ -947,9 +947,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
state_key,
all_partition_states,
max_elems_per_comm):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
if not torch.is_tensor(all_partition_states[0]):
return all_partition_states[0]
alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_partition_states,
dp=dist.get_world_size(group=self.dp_process_group),
@@ -964,6 +965,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
dp_process_group=self.dp_process_group
)
partition_id = dist.get_rank(group=self.dp_process_group)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
# Compute the optimizer state partitions for the group by
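For context, the early return added above is what lets non-tensor optimizer state (for example an integer step counter) flow through the state-partitioning path untouched, while tensor state continues to be flattened and split across data-parallel ranks. Below is a minimal standalone sketch of that pattern, not DeepSpeed's actual helper; the alignment done by flatten_dense_tensors_sub_partition_aligned is omitted and partition_state is an illustrative name.

import torch

def partition_state(all_partition_states, rank, world_size):
    # Non-tensor values (ints, floats, flags) are identical on every rank,
    # so the first copy can be reused directly -- mirroring the early return
    # in the diff above.
    if not torch.is_tensor(all_partition_states[0]):
        return all_partition_states[0]
    # Tensor values are flattened and this rank keeps only its slice
    # (the real code aligns the flat buffer before splitting).
    flat = torch.cat([t.flatten() for t in all_partition_states])
    return flat.chunk(world_size)[rank]

assert partition_state([7, 7], rank=0, world_size=2) == 7  # passes through unchanged
print(partition_state([torch.ones(4), torch.zeros(4)], rank=1, world_size=2))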
@@ -1013,8 +1015,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
for group_idx, group in enumerate(self.optimizer.param_groups):
for param_idx, param in enumerate(group['params']):
for key, saved in base_optimizer_group_states[group_idx].items():
current = self.optimizer.state[param][key]
current.data.copy_(saved[param_idx].data)
if torch.is_tensor(self.optimizer.state[param][key]):
current = self.optimizer.state[param][key]
current.data.copy_(saved[param_idx].data)
else:
self.optimizer.state[param][key] = saved
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
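On the restore side, the new branch above copies tensor state into the existing buffers (preserving their device and dtype) and assigns non-tensor state directly. A small self-contained sketch of the same branching against a stock torch.optim.SGD; restore_param_state is an illustrative name, not a DeepSpeed API.

import torch

def restore_param_state(optimizer, param, saved_state):
    for key, saved in saved_state.items():
        if torch.is_tensor(optimizer.state[param][key]):
            # In-place copy keeps the existing tensor's device and dtype.
            optimizer.state[param][key].data.copy_(saved.data)
        else:
            # Plain Python values (step counts, flags) are assigned as-is.
            optimizer.state[param][key] = saved

# Example with SGD's momentum buffer:
p = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
p.grad = torch.ones(2)
opt.step()  # creates a 'momentum_buffer' tensor in opt.state[p]
restore_param_state(opt, p, {'momentum_buffer': torch.zeros(2)})
assert torch.equal(opt.state[p]['momentum_buffer'], torch.zeros(2))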

View file

@@ -101,6 +101,37 @@ class SimpleOptimizer(torch.optim.Optimizer):
return loss
class HybridStateOptimizer(torch.optim.Optimizer):
def __init__(self, params, lr=0.11072018):
defaults = dict(lr=lr)
super(HybridStateOptimizer, self).__init__(params, defaults)
def __setstate__(self, state):
super(HybridStateOptimizer, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
if len(state) == 0:
state['integer_step'] = 0
state['tensor_step'] = torch.zeros(1)
d_p = p.grad.data
p.data.add_(-group['lr'], d_p)
state['integer_step'] += 1
state['tensor_step'] += 1
return loss
class PLD_SimpleModel(SimpleModel):
def __init__(self, hidden_dim, empty_grad=False, rank=0):
super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
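The HybridStateOptimizer added above deliberately keeps both a plain Python int and a one-element tensor in each parameter's state, which is exactly the mix the checkpoint path now has to handle. A quick standalone check of that state, run outside DeepSpeed and assuming a PyTorch release contemporary with this PR (step() uses the deprecated add_(alpha, tensor) overload):

import torch

model = torch.nn.Linear(10, 10)
optimizer = HybridStateOptimizer(model.parameters())
model(torch.randn(4, 10)).sum().backward()
optimizer.step()

state = optimizer.state[next(model.parameters())]
assert isinstance(state['integer_step'], int)   # non-tensor state
assert torch.is_tensor(state['tensor_step'])    # tensor state
# state_dict() therefore mixes both kinds, which the ZeRO wrappers must
# now save and restore without assuming everything is a tensor.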

View file

@@ -36,6 +36,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
compare_deepspeed_states(saved_model, loaded_model)
for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
if not compare_optimizer:
@@ -43,20 +44,24 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
for p0, p1 in zip(partition0, partition1):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_Optimizer):
for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
for p0, p1 in zip(params0, params1):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
pass
@@ -72,6 +77,7 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
loaded_optimizer.state.values()):
for s0, s1 in zip(state0.values(), state1.values()):
if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
assert torch.equal(s0, s1)
else:
assert s0 == s1
@@ -100,18 +106,34 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
assert state0 == state1
def create_deepspeed_model(args, model, base_optimizer):
if base_optimizer is None:
ds_model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
else:
ds_model, _, _, _ = deepspeed.initialize(args=args,
model=model,
optimizer=base_optimizer)
return ds_model
def checkpoint_correctness_verification(args,
model,
models,
hidden_dim,
tmpdir,
load_optimizer_states=False,
load_lr_scheduler_states=False,
fp16=True,
train_batch=False):
train_batch=False,
base_optimizers=[None,
None]):
dtype = torch.half if fp16 else torch.float32
ds_model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
ds_model = create_deepspeed_model(args=args,
model=models[0],
base_optimizer=base_optimizers[0])
data_loader = random_dataloader(model=ds_model,
total_samples=50,
hidden_dim=hidden_dim,
@@ -125,7 +147,6 @@ def checkpoint_correctness_verification(args,
else:
for n, batch in enumerate(data_loader):
loss = ds_model(batch[0], batch[1])
print(loss)
ds_model.backward(loss)
ds_model.step()
@@ -136,9 +157,9 @@ def checkpoint_correctness_verification(args,
trained_model.save_checkpoint(save_folder, save_tag)
loaded_model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
loaded_model = create_deepspeed_model(args=args,
model=models[1],
base_optimizer=base_optimizers[1])
loaded_model.load_checkpoint(save_folder,
save_tag,
@@ -191,25 +212,26 @@ def test_checkpoint_unfused_optimizer(tmpdir):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_unfused_optimizer(args,
model,
models,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_unfused_optimizer(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=True)
_test_checkpoint_unfused_optimizer(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=False)
@@ -236,22 +258,26 @@ def test_checkpoint_fused_optimizer(tmpdir):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
def _test_checkpoint_fused_optimizer(args,
models,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_fused_optimizer(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=True)
_test_checkpoint_fused_optimizer(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=False)
@@ -293,18 +319,18 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
def _test_checkpoint_zero_optimizer(args, models, hidden_dim, load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_zero_optimizer(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=True)
@@ -346,21 +372,21 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_zero_no_optimizer(args,
model,
models,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_zero_no_optimizer(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=False)
@@ -412,24 +438,24 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_lr_scheduler(args,
model,
models,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
_test_checkpoint_lr_scheduler(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=True)
@@ -478,24 +504,24 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_no_lr_scheduler(args,
model,
models,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
_test_checkpoint_no_lr_scheduler(args=args,
model=model,
models=models,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=False)
@@ -523,13 +549,17 @@ def test_checkpoint_fp32_optimizer(tmpdir):
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@distributed_test(world_size=[2])
def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)
def _test_checkpoint_fp32_optimizer(args, models, hidden_dim):
checkpoint_correctness_verification(args,
models=models,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
fp16=False)
_test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
_test_checkpoint_fp32_optimizer(args=args, models=models, hidden_dim=hidden_dim)
@pytest.mark.parametrize("zero_stage", [0, 1])
@@ -571,10 +601,10 @@ def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
@distributed_test(world_size=4)
def _test(save_folder, num_stages):
args = args_from_dict(tmpdir, config_dict)
model = LinearStackPipe(num_stages=num_stages)
models = [LinearStackPipe(num_stages=num_stages) for _ in range(2)]
checkpoint_correctness_verification(args=args,
model=model,
hidden_dim=model.hidden_dim,
models=models,
hidden_dim=models[0].hidden_dim,
tmpdir=save_folder,
fp16=config_dict['fp16']['enabled'],
load_optimizer_states=True,
@@ -635,3 +665,42 @@ def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"
_test(base_topo, test_topo, save_folder=tmpdir)
@pytest.mark.parametrize('zero_stage', [1, 2])
def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
config_dict = {
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 2,
"steps_per_print": 1,
"zero_optimization": {
"stage": zero_stage
},
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": True,
"initial_scale_power": 8
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
optimizers = [HybridStateOptimizer(model.parameters()) for model in models]
@distributed_test(world_size=[2])
def _test_checkpoint_zero_hybrid_optimizer_state(args,
models,
optimizers,
hidden_dim):
checkpoint_correctness_verification(args,
models=models,
base_optimizers=optimizers,
hidden_dim=hidden_dim,
tmpdir=tmpdir,
load_optimizer_states=True)
_test_checkpoint_zero_hybrid_optimizer_state(args=args,
models=models,
optimizers=optimizers,
hidden_dim=hidden_dim)
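The new test drives a full save/load round trip through checkpoint_correctness_verification with a client optimizer and zero_allow_untested_optimizer enabled. For reference, torch's own Optimizer.state_dict()/load_state_dict() already round-trips mixed tensor and non-tensor state; the standalone sketch below (outside DeepSpeed, reusing the HybridStateOptimizer defined in the test helpers and with the same PyTorch-era caveat as above) shows the behavior the ZeRO wrappers are being brought in line with.

import torch

src_model = torch.nn.Linear(10, 10)
src = HybridStateOptimizer(src_model.parameters())
src_model(torch.randn(4, 10)).sum().backward()
src.step()

dst = HybridStateOptimizer(torch.nn.Linear(10, 10).parameters())
dst.load_state_dict(src.state_dict())

for state in dst.state.values():
    assert state['integer_step'] == 1                        # restored as an int
    assert torch.equal(state['tensor_step'], torch.ones(1))  # restored as a tensor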