Mirror of https://github.com/microsoft/DeepSpeed.git
Add explicit parameters for torch.load (#6751)
Successor PR to #6094:

> FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

Todo:
- [ ] Update values in non-test files to True where necessary.
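For context, a minimal sketch of the two options the warning leaves a caller. The checkpoint filename and the `argparse.Namespace` allowlist entry are illustrative assumptions, not part of this PR:

```python
import argparse

import torch

# Option taken by this PR: pin today's behavior explicitly. The call is
# semantically unchanged, but the FutureWarning goes away and the eventual
# default flip to weights_only=True becomes a no-op for this call site.
# Only appropriate for checkpoint files you fully trust.
sd = torch.load("mp_rank_00_model_states.pt", map_location="cpu", weights_only=False)

# Stricter option tracked by the todo above: opt into weights_only=True and
# allowlist the non-tensor types the checkpoint legitimately contains
# (requires torch >= 2.4; argparse.Namespace is a plausible example, since
# these checkpoints can store training args alongside tensors).
torch.serialization.add_safe_globals([argparse.Namespace])
sd = torch.load("mp_rank_00_model_states.pt", map_location="cpu", weights_only=True)
```

With the allowlist route, loading an untrusted file fails fast instead of executing arbitrary pickled code, which is the point of the upcoming default flip.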
Parent: 1fdad1fa52
Commit: 2e0c39b55c
@@ -116,7 +116,7 @@ class DeepSpeedCheckpoint(object):
         self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')
 
     def _build_global_state(self):
-        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
         self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
         self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
 
@@ -137,14 +137,17 @@ class DeepSpeedCheckpoint(object):
 
     def get_iteration(self):
         if not ITERATION_KEY in self.global_state:
-            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
             self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
 
         return self.global_state[ITERATION_KEY]
 
     def get_embedding_state(self, tp_index: int) -> Dict:
         assert tp_index in self.tp_to_embedding_map.keys()
-        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
+        sd_list = [
+            torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
+            for fname in self.tp_to_embedding_map[tp_index]
+        ]
         sd = self._merge_state_dicts(sd_list)
         return sd
 
@@ -154,7 +157,7 @@ class DeepSpeedCheckpoint(object):
 
     def _get_checkpoint_value(self, key):
         if not key in self.global_state:
-            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
             self.global_state[key] = sd.get(key, None)
 
         return self.global_state[key]
@@ -169,7 +172,7 @@ class DeepSpeedCheckpoint(object):
         assert tp_index < self.tp_degree
         assert pp_index < self.pp_degree
         fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
-        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+        sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
 
         merged_sd = None
         for sd in sd_list:
@@ -185,7 +188,7 @@ class DeepSpeedCheckpoint(object):
         assert pp_index < self.pp_degree
         t_list = []
         for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
-            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+            sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
             sd = self._merge_state_dicts(sd_list)
             t_list.append(sd)
         return t_list
@@ -196,7 +199,7 @@ class DeepSpeedCheckpoint(object):
 
     def get_final_norm_state(self, tp_index: int) -> Dict:
         assert tp_index in self.tp_to_final_norm_map.keys()
-        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
+        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
         return sd
 
     def get_final_norm_files(self, tp_index: int) -> list:

@@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
 
 
 def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
-    state_dict = torch.load(optim_files[dp_index], map_location='cpu')
+    state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)
 
     flat_state = dict(
         exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
@@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
             raise ValueError(f"Cannot parse dp_rank from {p}")
 
     paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
-    shards = [torch.load(p) for p in paths]
+    shards = [torch.load(p, weights_only=False) for p in paths]
 
     if state == "step":
         assert all(v == shards[0] for v in shards), "All shards must have the same step value"
@@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):
 
 
 def _parse_model_states_stage3(files):
-    return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
+    return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]
 
 
 def _save_optimizer_state(args, ds_checkpoint):
@@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):
 
 
 def _save_optimizer_state_stage3(args, optim_files):
-    sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
+    sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
     output_sd = sd[OPTIMIZER_STATE_DICT]
     output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
     zero_output_folder = os.path.join(args.output_folder, "zero")
@@ -446,7 +446,7 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):
 
 
 def _get_zero_stage(optim_files):
-    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
+    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
     optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
     zero_stage = optimizer_state.get(ZERO_STAGE, 1)
     return zero_stage
@@ -454,7 +454,7 @@ def _get_zero_stage(optim_files):
 
 def _inject_missing_state(ds_checkpoint):
     if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
-        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
+        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
         if UNIVERSAL_CHECKPOINT_INFO not in sd:
             ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
             ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
@@ -488,7 +488,7 @@ def main(args):
 
     slice_shapes = []
     for mp_rank_file in ds_checkpoint.mp_rank_files:
-        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
+        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
         slice_shapes += mp_sd[PARAM_SHAPES]
 
     # fix back to normal flat dict, merge duplicates for tp>1

@@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     step = None
     for key in hp_keys:
         ckpt_file = os.path.join(folder, f"{key}.pt")
-        ckpt_dict = torch.load(ckpt_file)
+        ckpt_dict = torch.load(ckpt_file, weights_only=False)
 
         if key == "step":
             step = ckpt_dict

@@ -54,7 +54,7 @@ class ZeROCheckpoint(object):
         state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
         merged_sd = None
         for state_file in state_file_list:
-            sd = torch.load(state_file, map_location=torch.device('cpu'))
+            sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
             for key in keys_to_ignore:
                 sd.pop(key, None)
 

@@ -452,7 +452,7 @@ class InferenceEngine(Module):
             checkpoint = sd_loader['checkpoints']
 
             if type(checkpoint) is list:
-                self.sd = torch.load(checkpoint[0], map_location='cpu')
+                self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
                 self.key_list = list(self.sd.keys())
 
                 self.load_model_with_checkpoint(self.module)
@@ -460,7 +460,7 @@ class InferenceEngine(Module):
                 for i in range(1, len(checkpoint)):
                     if not dist.is_initialized() or dist.get_rank() == 0:
                         print(f"loading checkpoint ({i})")
-                    self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
+                    self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
                     self.key_list = list(self.sd.keys())
                     self.load_model_with_checkpoint(self.module)
             else:

@@ -80,7 +80,7 @@ class HuggingFaceCheckpointEngine(CheckpointEngineBase):
         else:
             model_param_json_fname = "pytorch_model.bin.index.json"
             model_file_fname = "pytorch_model.bin"
-            self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
+            self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)
 
         model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)
 

@@ -205,7 +205,7 @@ class InferenceV2Policy(ABC, metaclass=PolicyMeta):
         buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
         metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
 
-        buffer = torch.load(buffer_path)
+        buffer = torch.load(buffer_path, weights_only=False)
         metadata = json.load(open(metadata_path, "r"))
         metadata = ModelMetadata.parse_raw(metadata)
 

@@ -415,7 +415,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
             pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
 
             for i in range(len(checkpoint)):
-                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
+                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
                 load_model_with_checkpoint(replaced_module,
                                            sd,
                                            mp_replace,
@@ -437,7 +437,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
                     os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                     for j in range(sd_count)
                 ]
-                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
+                sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
                 load_model_with_checkpoint(replaced_module,
                                            sds,
                                            mp_replace,
@@ -457,7 +457,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
                     pbar.update(1)
                     ckpt_file = os.path.join(base_dir1,
                                              checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
-                    sds = [torch.load(ckpt_file, map_location='cpu')]
+                    sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
                     load_model_with_checkpoint(replaced_module,
                                                sds,
                                                mp_replace,
@@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
             from safetensors.torch import load_file
             sd = load_file(checkpoint)
         else:
-            sd = torch.load(checkpoint, map_location='cpu')
+            sd = torch.load(checkpoint, map_location='cpu', weights_only=False)
 
     policy = {}
     if orig_class is not None:

@@ -22,7 +22,7 @@ class ZeROOptimizer(DeepSpeedOptimizer):
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
         assert os.path.isfile(
             optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path)
+        optim_sd = torch.load(optim_state_path, weights_only=False)
 
         self._load_global_state(optim_sd)
 

@@ -58,7 +58,7 @@ class NebulaCheckpointEngine(CheckpointEngine):
         if not self.enable_nebula_load and first_load_flag:
             self.tag_flag = tag
             logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
-            partition = torch.load(path, map_location=map_location)
+            partition = torch.load(path, map_location=map_location, weights_only=False)
             logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
             return partition
 

@@ -25,7 +25,7 @@ class TorchCheckpointEngine(CheckpointEngine):
 
     def load(self, path: str, map_location=None):
         logger.info(f"[Torch] Loading checkpoint from {path}...")
-        partition = torch.load(path, map_location=map_location)
+        partition = torch.load(path, map_location=map_location, weights_only=False)
         logger.info(f"[Torch] Loaded checkpoint from {path}.")
         return partition
 

@@ -2741,7 +2741,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         assert os.path.isfile(
             optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
 
-        optim_sd = torch.load(optim_state_path)
+        optim_sd = torch.load(optim_state_path, weights_only=False)
         self._load_global_state_stage3(optim_sd)
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
@@ -2799,7 +2799,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
+        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
 
         # Partition the loaded data according to the local rank
         world_size = dist.get_world_size(group=self.dp_process_group)

@@ -102,7 +102,7 @@ def get_model_state_files(checkpoint_dir):
 def parse_model_states(files):
     zero_model_states = []
     for file in files:
-        state_dict = torch.load(file, map_location=device)
+        state_dict = torch.load(file, map_location=device, weights_only=False)
 
         if BUFFER_NAMES not in state_dict:
             raise ValueError(f"{file} is not a model state checkpoint")
@@ -149,7 +149,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
     for f in tqdm(files, desc='Loading checkpoint shards'):
-        state_dict = torch.load(f, map_location=device, mmap=True)
+        state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
         # and also handle the case where it was already removed by another helper script
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)

@@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
     for root, _, files in os.walk(save_folder):
         for f in files:
             if "_expert_" in f and "_model_states" in f:
-                expert = torch.load(os.path.join(root, f))
+                expert = torch.load(os.path.join(root, f), weights_only=False)
                 needed, storages = 0, {}
                 for name, tensor in expert.items():
                     needed += tensor.size().numel()

@@ -181,7 +181,7 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
         )
 
         hidden_dim = 10
-        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
+        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)
 
         ds_config["checkpoint"] = {"load_universal": True}
         univ_model = SimpleModel(hidden_dim)

@@ -264,7 +264,7 @@ class TestZeROElasticCheckpoint(DistributedTest):
         model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)
 
         if load_optim:
-            saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file))
+            saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
             curr_sd = model.optimizer.optimizer.state_dict()
             compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)
 
@@ -523,7 +523,7 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
         all_ckpt_folder = os.path.join(tmpdir, 'all_params')
         ds_engine.save_checkpoint(all_ckpt_folder)
         all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
-        loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
+        loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
         all_param_names = set([n for n, p in model.named_parameters()])
         assert set(loaded_all_param_model.keys()) == all_param_names
 
@@ -536,7 +536,7 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
         # Excluding frozen parameters should reduce checkpoint size
         assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)
 
-        loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
+        loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
         frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
         loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
         overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
@@ -575,7 +575,7 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
 
         custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
             os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
-        loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module']
+        loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
         loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())
 
         custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
@@ -618,7 +618,8 @@ class TestSaveTensorClone(DistributedTest):
         clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
         torch.save(clone_state_dict, clone_ckpt_file)
 
-        compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file))
+        compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False),
+                            torch.load(clone_ckpt_file, weights_only=False))
 
 
 class TestZeRONonDistributed(DistributedTest):

@@ -170,7 +170,7 @@ class TestConfigurableResizeMP(ConfigurableMP):
         test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
         if dist.get_rank() == 0:
             load_path = os.path.join(class_tmpdir, "output.pt")
-            baseline = torch.load(load_path)
+            baseline = torch.load(load_path, weights_only=False)
             test = test.cpu()
             assert torch.allclose(
                 baseline, test,
@@ -225,7 +225,7 @@ class TestConfigurableResizePP(ConfigurablePP):
             assert torch.is_tensor(test[0][0])
             test = test[0][0].cpu()
             load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
-            baseline = torch.load(load_path)
+            baseline = torch.load(load_path, weights_only=False)
             assert torch.allclose(
                 baseline, test,
                 atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"