Fix sort of zero checkpoint files (#5342)

The conversion from a regular checkpoint to universal one relies on
sorting of zero checkpoint files to merge sharded optimizer states. This
merge can silently produce wrong results as the sorting is in
alphabetical order.

The merging logic assumes that files are given in this order.
1. pp_index=0 tp_index=0 dp_index=0
2. pp_index=0 tp_index=0 dp_index=1
...

The optimizer state of a parameter can be sharded across multiple ranks.
If it is sharded across dp_index 9-11, the files will be
- bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt
- bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt
- bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt 
 
As they are sorted in alphabetical order, the script merges the sharded
fragment in the order of [10, 11, 9].
This PR fixes this sort to extracts dp ranks in files and sort the files
treating the ranks as numbers.

Fix #5283

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Masahiro Tanaka 2024-04-01 16:32:48 -07:00 коммит произвёл GitHub
Родитель 40009eb1c7
Коммит c946a34220
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 19 добавлений и 2 удалений

Просмотреть файл

@ -4,9 +4,10 @@
# DeepSpeed Team
import os
import re
import torch
from collections import OrderedDict
from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)
from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX, MODEL_FILE_PREFIX)
def basic_folder_validation(dir):
@ -38,12 +39,28 @@ def get_files(dir):
return file_list
def sort_zero_files(files, prefix):
pattern = f"{prefix}([0-9]+)_{MODEL_FILE_PREFIX}([0-9]+)"
rank_pairs = []
for f in files:
m = re.search(pattern, f)
if m:
dp_rank = int(m.group(1))
mp_rank = int(m.group(2))
rank_pairs.append((dp_rank, mp_rank, f))
else:
raise ValueError(f"Cannot parse dp_rank and mp_rank from {f}")
sorted_files = sorted(rank_pairs, key=lambda x: (x[0], x[1]))
return [f for _, _, f in sorted_files]
def get_zero_files(dir):
file_list = get_files(dir)
for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
zero_files = get_files_with_prefix(file_list, prefix)
if len(zero_files) > 0:
return zero_files
return sort_zero_files(zero_files, prefix)
return []