зеркало из https://github.com/microsoft/DeepSpeed.git
zero_to_fp32.py: Handle a case where shape doesn't have numel attr (#4842)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Родитель
ac84cf3ff1
Коммит
691458f8b6
|
@ -248,6 +248,11 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
|||
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _has_callable(obj, fn):
|
||||
attr = getattr(obj, fn, None)
|
||||
return callable(attr)
|
||||
|
||||
|
||||
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
||||
param_shapes = zero_model_states[0].param_shapes
|
||||
|
||||
|
@ -287,7 +292,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
|||
avail_numel = full_single_fp32_vector.numel()
|
||||
for name, shape in shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel()
|
||||
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче