Add explicit parameters for torch.load (#6751)

Successor PR to #6094:

> FutureWarning: You are using torch.load with weights_only=False (the
current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute
arbitrary code during unpickling (See
https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models
for more details). In a future release, the default value for
weights_only will be flipped to True. This limits the functions that
could be executed during unpickling. Arbitrary objects will no longer be
allowed to be loaded via this mode unless they are explicitly
allowlisted by the user via torch.serialization.add_safe_globals. We
recommend you start setting weights_only=True for any use case where you
don't have full control of the loaded file. Please open an issue on
GitHub for any issues related to this experimental feature.

Todo:
- [ ] Update values in non-test files to True where necessary.
This commit is contained in:
Logan Adams 2024-11-19 11:09:52 -08:00 коммит произвёл GitHub
Родитель 1fdad1fa52
Коммит 2e0c39b55c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
18 изменённых файлов: 44 добавлений и 40 удалений

Просмотреть файл

@ -116,7 +116,7 @@ class DeepSpeedCheckpoint(object):
self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files') self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')
def _build_global_state(self): def _build_global_state(self):
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
@ -137,14 +137,17 @@ class DeepSpeedCheckpoint(object):
def get_iteration(self): def get_iteration(self):
if not ITERATION_KEY in self.global_state: if not ITERATION_KEY in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
return self.global_state[ITERATION_KEY] return self.global_state[ITERATION_KEY]
def get_embedding_state(self, tp_index: int) -> Dict: def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys() assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]] sd_list = [
torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
for fname in self.tp_to_embedding_map[tp_index]
]
sd = self._merge_state_dicts(sd_list) sd = self._merge_state_dicts(sd_list)
return sd return sd
@ -154,7 +157,7 @@ class DeepSpeedCheckpoint(object):
def _get_checkpoint_value(self, key): def _get_checkpoint_value(self, key):
if not key in self.global_state: if not key in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
self.global_state[key] = sd.get(key, None) self.global_state[key] = sd.get(key, None)
return self.global_state[key] return self.global_state[key]
@ -169,7 +172,7 @@ class DeepSpeedCheckpoint(object):
assert tp_index < self.tp_degree assert tp_index < self.tp_degree
assert pp_index < self.pp_degree assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index) fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
merged_sd = None merged_sd = None
for sd in sd_list: for sd in sd_list:
@ -185,7 +188,7 @@ class DeepSpeedCheckpoint(object):
assert pp_index < self.pp_degree assert pp_index < self.pp_degree
t_list = [] t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]: for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
sd = self._merge_state_dicts(sd_list) sd = self._merge_state_dicts(sd_list)
t_list.append(sd) t_list.append(sd)
return t_list return t_list
@ -196,7 +199,7 @@ class DeepSpeedCheckpoint(object):
def get_final_norm_state(self, tp_index: int) -> Dict: def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys() assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu')) sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
return sd return sd
def get_final_norm_files(self, tp_index: int) -> list: def get_final_norm_files(self, tp_index: int) -> list:

Просмотреть файл

@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index): def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu') state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)
flat_state = dict( flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"], exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
raise ValueError(f"Cannot parse dp_rank from {p}") raise ValueError(f"Cannot parse dp_rank from {p}")
paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
shards = [torch.load(p) for p in paths] shards = [torch.load(p, weights_only=False) for p in paths]
if state == "step": if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value" assert all(v == shards[0] for v in shards), "All shards must have the same step value"
@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):
def _parse_model_states_stage3(files): def _parse_model_states_stage3(files):
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES] return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]
def _save_optimizer_state(args, ds_checkpoint): def _save_optimizer_state(args, ds_checkpoint):
@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):
def _save_optimizer_state_stage3(args, optim_files): def _save_optimizer_state_stage3(args, optim_files):
sd = torch.load(optim_files[0], map_location=torch.device('cpu')) sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
output_sd = sd[OPTIMIZER_STATE_DICT] output_sd = sd[OPTIMIZER_STATE_DICT]
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS] output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero") zero_output_folder = os.path.join(args.output_folder, "zero")
@ -446,7 +446,7 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):
def _get_zero_stage(optim_files): def _get_zero_stage(optim_files):
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu')) state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
optimizer_state = state_dict[OPTIMIZER_STATE_DICT] optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
zero_stage = optimizer_state.get(ZERO_STAGE, 1) zero_stage = optimizer_state.get(ZERO_STAGE, 1)
return zero_stage return zero_stage
@ -454,7 +454,7 @@ def _get_zero_stage(optim_files):
def _inject_missing_state(ds_checkpoint): def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state: if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu')) sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
if UNIVERSAL_CHECKPOINT_INFO not in sd: if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {} ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][ ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
@ -488,7 +488,7 @@ def main(args):
slice_shapes = [] slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files: for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu')) mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
slice_shapes += mp_sd[PARAM_SHAPES] slice_shapes += mp_sd[PARAM_SHAPES]
# fix back to normal flat dict, merge duplicates for tp>1 # fix back to normal flat dict, merge duplicates for tp>1

Просмотреть файл

@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
step = None step = None
for key in hp_keys: for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt") ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file) ckpt_dict = torch.load(ckpt_file, weights_only=False)
if key == "step": if key == "step":
step = ckpt_dict step = ckpt_dict

Просмотреть файл

@ -54,7 +54,7 @@ class ZeROCheckpoint(object):
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index) state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None merged_sd = None
for state_file in state_file_list: for state_file in state_file_list:
sd = torch.load(state_file, map_location=torch.device('cpu')) sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
for key in keys_to_ignore: for key in keys_to_ignore:
sd.pop(key, None) sd.pop(key, None)

Просмотреть файл

@ -452,7 +452,7 @@ class InferenceEngine(Module):
checkpoint = sd_loader['checkpoints'] checkpoint = sd_loader['checkpoints']
if type(checkpoint) is list: if type(checkpoint) is list:
self.sd = torch.load(checkpoint[0], map_location='cpu') self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
self.key_list = list(self.sd.keys()) self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module) self.load_model_with_checkpoint(self.module)
@ -460,7 +460,7 @@ class InferenceEngine(Module):
for i in range(1, len(checkpoint)): for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0: if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})") print(f"loading checkpoint ({i})")
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name()) self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
self.key_list = list(self.sd.keys()) self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module) self.load_model_with_checkpoint(self.module)
else: else:

Просмотреть файл

@ -80,7 +80,7 @@ class HuggingFaceCheckpointEngine(CheckpointEngineBase):
else: else:
model_param_json_fname = "pytorch_model.bin.index.json" model_param_json_fname = "pytorch_model.bin.index.json"
model_file_fname = "pytorch_model.bin" model_file_fname = "pytorch_model.bin"
self._checkpoint_load_fn = partial(torch.load, map_location="cpu") self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)
model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname) model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)

Просмотреть файл

@ -205,7 +205,7 @@ class InferenceV2Policy(ABC, metaclass=PolicyMeta):
buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size) buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size) metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
buffer = torch.load(buffer_path) buffer = torch.load(buffer_path, weights_only=False)
metadata = json.load(open(metadata_path, "r")) metadata = json.load(open(metadata_path, "r"))
metadata = ModelMetadata.parse_raw(metadata) metadata = ModelMetadata.parse_raw(metadata)

Просмотреть файл

@ -415,7 +415,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)): for i in range(len(checkpoint)):
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')] sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module, load_model_with_checkpoint(replaced_module,
sd, sd,
mp_replace, mp_replace,
@ -437,7 +437,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j] os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count) for j in range(sd_count)
] ]
sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files] sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module, load_model_with_checkpoint(replaced_module,
sds, sds,
mp_replace, mp_replace,
@ -457,7 +457,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
pbar.update(1) pbar.update(1)
ckpt_file = os.path.join(base_dir1, ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i] checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')] sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
load_model_with_checkpoint(replaced_module, load_model_with_checkpoint(replaced_module,
sds, sds,
mp_replace, mp_replace,
@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
from safetensors.torch import load_file from safetensors.torch import load_file
sd = load_file(checkpoint) sd = load_file(checkpoint)
else: else:
sd = torch.load(checkpoint, map_location='cpu') sd = torch.load(checkpoint, map_location='cpu', weights_only=False)
policy = {} policy = {}
if orig_class is not None: if orig_class is not None:

Просмотреть файл

@ -22,7 +22,7 @@ class ZeROOptimizer(DeepSpeedOptimizer):
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile( assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path) optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state(optim_sd) self._load_global_state(optim_sd)

Просмотреть файл

@ -58,7 +58,7 @@ class NebulaCheckpointEngine(CheckpointEngine):
if not self.enable_nebula_load and first_load_flag: if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag self.tag_flag = tag
logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...") logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
partition = torch.load(path, map_location=map_location) partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .") logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition return partition

Просмотреть файл

@ -25,7 +25,7 @@ class TorchCheckpointEngine(CheckpointEngine):
def load(self, path: str, map_location=None): def load(self, path: str, map_location=None):
logger.info(f"[Torch] Loading checkpoint from {path}...") logger.info(f"[Torch] Loading checkpoint from {path}...")
partition = torch.load(path, map_location=map_location) partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Torch] Loaded checkpoint from {path}.") logger.info(f"[Torch] Loaded checkpoint from {path}.")
return partition return partition

Просмотреть файл

@ -2741,7 +2741,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
assert os.path.isfile( assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path) optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state_stage3(optim_sd) self._load_global_state_stage3(optim_sd)
key_list = ["fp32", "exp_avg", "exp_avg_sq"] key_list = ["fp32", "exp_avg", "exp_avg_sq"]
@ -2799,7 +2799,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
local_rank = dist.get_local_rank() local_rank = dist.get_local_rank()
# Load tensors from files and reshape them to flat vectors # Load tensors from files and reshape them to flat vectors
loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1) loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
# Partition the loaded data according to the local rank # Partition the loaded data according to the local rank
world_size = dist.get_world_size(group=self.dp_process_group) world_size = dist.get_world_size(group=self.dp_process_group)

Просмотреть файл

@ -102,7 +102,7 @@ def get_model_state_files(checkpoint_dir):
def parse_model_states(files): def parse_model_states(files):
zero_model_states = [] zero_model_states = []
for file in files: for file in files:
state_dict = torch.load(file, map_location=device) state_dict = torch.load(file, map_location=device, weights_only=False)
if BUFFER_NAMES not in state_dict: if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint") raise ValueError(f"{file} is not a model state checkpoint")
@ -149,7 +149,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files) total_files = len(files)
state_dicts = [] state_dicts = []
for f in tqdm(files, desc='Loading checkpoint shards'): for f in tqdm(files, desc='Loading checkpoint shards'):
state_dict = torch.load(f, map_location=device, mmap=True) state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script # and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)

Просмотреть файл

@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
for root, _, files in os.walk(save_folder): for root, _, files in os.walk(save_folder):
for f in files: for f in files:
if "_expert_" in f and "_model_states" in f: if "_expert_" in f and "_model_states" in f:
expert = torch.load(os.path.join(root, f)) expert = torch.load(os.path.join(root, f), weights_only=False)
needed, storages = 0, {} needed, storages = 0, {}
for name, tensor in expert.items(): for name, tensor in expert.items():
needed += tensor.size().numel() needed += tensor.size().numel()

Просмотреть файл

@ -181,7 +181,7 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
) )
hidden_dim = 10 hidden_dim = 10
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt") loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)
ds_config["checkpoint"] = {"load_universal": True} ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim) univ_model = SimpleModel(hidden_dim)

Просмотреть файл

@ -264,7 +264,7 @@ class TestZeROElasticCheckpoint(DistributedTest):
model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)
if load_optim: if load_optim:
saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file)) saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
curr_sd = model.optimizer.optimizer.state_dict() curr_sd = model.optimizer.optimizer.state_dict()
compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys) compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)
@ -523,7 +523,7 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
all_ckpt_folder = os.path.join(tmpdir, 'all_params') all_ckpt_folder = os.path.join(tmpdir, 'all_params')
ds_engine.save_checkpoint(all_ckpt_folder) ds_engine.save_checkpoint(all_ckpt_folder)
all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00') all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
loaded_all_param_model = torch.load(all_params_ckpt_file)['module'] loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
all_param_names = set([n for n, p in model.named_parameters()]) all_param_names = set([n for n, p in model.named_parameters()])
assert set(loaded_all_param_model.keys()) == all_param_names assert set(loaded_all_param_model.keys()) == all_param_names
@ -536,7 +536,7 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
# Excluding frozen parameters should reduce checkpoint size # Excluding frozen parameters should reduce checkpoint size
assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file) assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)
loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module'] loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad]) frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
loaded_trainable_param_names = set(loaded_trainable_param_model.keys()) loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names) overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
@ -575,7 +575,7 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank( custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00') os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module'] loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys()) loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())
custom_state_dict_param_names = set([k for k, v in model.state_dict().items()]) custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
@ -618,7 +618,8 @@ class TestSaveTensorClone(DistributedTest):
clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt') clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
torch.save(clone_state_dict, clone_ckpt_file) torch.save(clone_state_dict, clone_ckpt_file)
compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file)) compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False),
torch.load(clone_ckpt_file, weights_only=False))
class TestZeRONonDistributed(DistributedTest): class TestZeRONonDistributed(DistributedTest):

Просмотреть файл

@ -170,7 +170,7 @@ class TestConfigurableResizeMP(ConfigurableMP):
test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name)) test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
if dist.get_rank() == 0: if dist.get_rank() == 0:
load_path = os.path.join(class_tmpdir, "output.pt") load_path = os.path.join(class_tmpdir, "output.pt")
baseline = torch.load(load_path) baseline = torch.load(load_path, weights_only=False)
test = test.cpu() test = test.cpu()
assert torch.allclose( assert torch.allclose(
baseline, test, baseline, test,

Просмотреть файл

@ -225,7 +225,7 @@ class TestConfigurableResizePP(ConfigurablePP):
assert torch.is_tensor(test[0][0]) assert torch.is_tensor(test[0][0])
test = test[0][0].cpu() test = test[0][0].cpu()
load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt") load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
baseline = torch.load(load_path) baseline = torch.load(load_path, weights_only=False)
assert torch.allclose( assert torch.allclose(
baseline, test, baseline, test,
atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}" atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"