зеркало из https://github.com/microsoft/DeepSpeed.git
Fix sort of zero checkpoint files (#5342)
The conversion from a regular checkpoint to universal one relies on sorting of zero checkpoint files to merge sharded optimizer states. This merge can silently produce wrong results as the sorting is in alphabetical order. The merging logic assumes that files are given in this order. 1. pp_index=0 tp_index=0 dp_index=0 2. pp_index=0 tp_index=0 dp_index=1 ... The optimizer state of a parameter can be sharded across multiple ranks. If it is sharded across dp_index 9-11, the files will be - bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt - bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt - bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt As they are sorted in alphabetical order, the script merges the sharded fragment in the order of [10, 11, 9]. This PR fixes this sort to extracts dp ranks in files and sort the files treating the ranks as numbers. Fix #5283 Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Родитель
40009eb1c7
Коммит
c946a34220
|
@ -4,9 +4,10 @@
|
|||
# DeepSpeed Team
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)
|
||||
from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX, MODEL_FILE_PREFIX)
|
||||
|
||||
|
||||
def basic_folder_validation(dir):
|
||||
|
@ -38,12 +39,28 @@ def get_files(dir):
|
|||
return file_list
|
||||
|
||||
|
||||
def sort_zero_files(files, prefix):
|
||||
pattern = f"{prefix}([0-9]+)_{MODEL_FILE_PREFIX}([0-9]+)"
|
||||
rank_pairs = []
|
||||
for f in files:
|
||||
m = re.search(pattern, f)
|
||||
if m:
|
||||
dp_rank = int(m.group(1))
|
||||
mp_rank = int(m.group(2))
|
||||
rank_pairs.append((dp_rank, mp_rank, f))
|
||||
else:
|
||||
raise ValueError(f"Cannot parse dp_rank and mp_rank from {f}")
|
||||
|
||||
sorted_files = sorted(rank_pairs, key=lambda x: (x[0], x[1]))
|
||||
return [f for _, _, f in sorted_files]
|
||||
|
||||
|
||||
def get_zero_files(dir):
|
||||
file_list = get_files(dir)
|
||||
for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
|
||||
zero_files = get_files_with_prefix(file_list, prefix)
|
||||
if len(zero_files) > 0:
|
||||
return zero_files
|
||||
return sort_zero_files(zero_files, prefix)
|
||||
|
||||
return []
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче