zero_to_fp32.py: Handle a case where shape doesn't have numel attr (#4842)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Nadav Elyahu 2024-01-03 21:45:06 +02:00 коммит произвёл GitHub
Родитель ac84cf3ff1
Коммит 691458f8b6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённый файл: 6 добавлений и 1 удаление

Просмотреть файл

@@ -248,6 +248,11 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states):
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _has_callable(obj, fn):
    """Return True if ``obj`` has an attribute named ``fn`` that is callable.

    A missing attribute is treated the same as a non-callable one: the lookup
    falls back to ``None``, and ``callable(None)`` is False, so no
    AttributeError is raised for objects that lack the attribute.

    Args:
        obj: Any object to inspect.
        fn: Name of the attribute to look up on ``obj``.

    Returns:
        bool: True when the attribute exists and can be called, else False.
    """
    return callable(getattr(obj, fn, None))
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
@@ -287,7 +292,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
avail_numel = full_single_fp32_vector.numel()
for name, shape in shapes.items():
unpartitioned_numel = shape.numel()
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
total_numel += unpartitioned_numel
total_params += 1