Mirror of https://github.com/microsoft/DeepSpeed.git
A faster and more memory-efficient implementation of `zero_to_fp32` (#6658)
A faster and more memory-efficient implementation of `zero_to_fp32`. The previous version doubled the memory usage, which caused CPU OOM for very large models (e.g. llama 405B). See b647fb2470/deepspeed/utils/zero_to_fp32.py (L438-L441).

## How does it work?

1. **Lazy loading**: Load the checkpoint with `mmap=True`, so the weights are memory-mapped rather than having all their storages loaded into memory.
2. **Lazy merge**: `GatheredTensor` holds the mmapped weights and the tensor offset. It is a memory-efficient pseudo tensor. Only when `tensor.contiguous()` is called does it load the related weights into memory and merge them into a single tensor.
3. **Release memory in time**: Save checkpoints shard by shard, and release the memory once a shard is saved. Throughout the process, only one shard of tensors is kept in memory (see the sketch after this message).

## How much benefit in speed and memory?

Experiments were conducted on a Linux host with 1 TB of memory. Here is a detailed comparison:

|                         | world size | peak memory (GB) | elapsed time (h:mm:ss) |
|-------------------------|------------|------------------|------------------------|
| llama3-8B (old->new)    | 8          | 90 -> 41         | 0:02:17 -> 0:01:10     |
| llama2-13B (old->new)   | 8          | 146 -> 54        | 0:02:30 -> 0:01:47     |
| llama2-70B (old->new)   | 16         | 789 -> 159       | 0:20:47 -> 0:20:45     |
| qwen1.5-110B (old->new) | 32         | OOM -> 217       | ? -> 0:34:21           |
| llama3-405B (old->new)  | 192        | OOM -> 262       | ? -> 2:09:59           |

You can reproduce the results with the following script:

```sh
# 1. install requirements
apt-get install time

# 2. prepare zero-3 checkpoints

# 3. convert zero to fp32 checkpoints
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
```

- **Memory**: Theoretically, this PR reduces the memory cost from `2M` to `(1/n)M`, where `M` is the memory cost of the full weights and `n` is the number of shards.
- **Speed**: The speed gain mainly comes from avoiding extra tensor copies. The benefit may be slight.

## Impl history

- [v1](19712a1c75 (diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447)): a hf_hub-compatible approach. It was discarded due to the controversial implementation of `data_ptr()`.
- [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simple approach with `torch.empty`.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
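The shard-by-shard save-and-release pattern from step 3 boils down to the following sketch. It is a simplified illustration distilled from the diff below, not the PR's exact code; `save_state_dict_by_shard` and its arguments are hypothetical names:

```python
import gc
import os
import torch

def save_state_dict_by_shard(state_dict, filename_to_tensors, output_dir):
    """Materialize, save, then free one shard at a time."""
    for shard_file, tensor_names in filename_to_tensors.items():
        # .contiguous() turns a lazy GatheredTensor into a real in-memory tensor
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        torch.save(shard, os.path.join(output_dir, shard_file))
        # drop the shard and its lazy entries so that at most one shard's
        # worth of fp32 weights is resident at any time
        for name in tensor_names:
            del state_dict[name]
        del shard
        gc.collect()
```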
Parent: f594dbe3df
Commit: dd40269426
deepspeed/utils/zero_to_fp32.py

@@ -21,7 +21,9 @@ import glob
 import math
 import os
 import re
+import gc
 import json
+import numpy as np
 from tqdm import tqdm
 from collections import OrderedDict
 from dataclasses import dataclass
@@ -146,8 +148,8 @@ def parse_model_states(files):
 def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
-    for f in files:
-        state_dict = torch.load(f, map_location=device)
+    for f in tqdm(files, desc='Loading checkpoint shards'):
+        state_dict = torch.load(f, map_location=device, mmap=True)
         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
         # and also handle the case where it was already removed by another helper script
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
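For context on step 1 of the PR description: with `mmap=True` (supported by `torch.load` since PyTorch 2.1 for checkpoints in the default zipfile format), the returned tensors are backed by the file on disk, and bytes are paged in only when a tensor is actually accessed. A small illustration, with placeholder file name and keys:

```python
import torch

# the file is mapped, not read: this returns quickly even for a huge checkpoint
sd = torch.load("optim_states.pt", map_location="cpu", mmap=True)

flat_groups = sd["optimizer_state_dict"]["fp32_flat_groups"]  # still on disk
chunk = flat_groups[0][:1024].clone()  # only the touched pages are read into RAM
```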
@@ -179,19 +181,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     else:
         raise ValueError(f"unknown zero stage {zero_stage}")

-    if zero_stage <= 2:
-        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
-    elif zero_stage == 3:
-        # if there is more than one param group, there will be multiple flattened tensors - one
-        # flattened tensor per group - for simplicity merge them into a single tensor
-        #
-        # XXX: could make the script more memory efficient for when there are multiple groups - it
-        # will require matching the sub-lists of param_shapes for each param group flattened tensor
-
-        fp32_flat_groups = [
-            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
-        ]
-
+    fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
     return zero_stage, world_size, fp32_flat_groups
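With the `torch.cat` removed, `fp32_flat_groups` now keeps one list of flat tensors per rank (one tensor per param group) instead of one concatenated tensor per rank. A toy illustration of the resulting structure, with made-up sizes:

```python
import torch

# one entry per rank; each entry is a list of flat tensors, one per param group
fp32_flat_groups = [
    [torch.zeros(10), torch.zeros(6)],  # rank 0
    [torch.zeros(10), torch.zeros(6)],  # rank 1
]
world_size = len(fp32_flat_groups)

# total element count, matching the new avail_numel computation further down
avail_numel = sum(g.numel() for g in fp32_flat_groups[0]) * world_size  # 32
```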
@@ -398,9 +388,56 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


+class GatheredTensor:
+    """
+    A pseudo tensor that collects partitioned weights.
+    It is more memory efficient when there are multiple groups.
+    """
+
+    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
+        self.flat_groups = flat_groups
+        self.flat_groups_offset = flat_groups_offset
+        self.offset = offset
+        self.partitioned_numel = partitioned_numel
+        self.shape = shape
+        self.dtype = self.flat_groups[0][0].dtype
+
+    def contiguous(self):
+        """
+        Merge partitioned weights from flat_groups into a single tensor.
+        """
+        end_idx = self.offset + self.partitioned_numel
+        world_size = len(self.flat_groups)
+        pad_flat_param_chunks = []
+
+        for rank_i in range(world_size):
+            # for each rank, we need to collect weights from related group/groups
+            flat_groups_at_rank_i = self.flat_groups[rank_i]
+            start_group_id = None
+            end_group_id = None
+            for group_id in range(len(self.flat_groups_offset)):
+                if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
+                    start_group_id = group_id
+                if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
+                    end_group_id = group_id
+                    break
+            # collect weights from related group/groups
+            for group_id in range(start_group_id, end_group_id + 1):
+                flat_tensor = flat_groups_at_rank_i[group_id]
+                start_offset = self.offset - self.flat_groups_offset[group_id]
+                end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
+                pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
+
+        # collect weights from all ranks
+        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
+        param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
+        return param
+
+
 def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     param_shapes = zero_model_states[0].param_shapes
-    avail_numel = fp32_flat_groups[0].numel() * world_size
+    avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size

     # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
     # param, re-consolidating each param, while dealing with padding if any
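The group lookup inside `contiguous()` works on prefix sums of the per-group sizes. A worked toy example of that bookkeeping, with made-up numbers:

```python
import numpy as np

group_numels = [10, 6, 8]                                 # per-group numel on one rank
flat_groups_offset = [0] + list(np.cumsum(group_numels))  # [0, 10, 16, 24]

offset, partitioned_numel = 12, 4                         # this param's slice: [12, 16)
end_idx = offset + partitioned_numel

# locate the group(s) covering [offset, end_idx), as GatheredTensor.contiguous() does
start_group_id = next(g for g in range(len(group_numels))
                      if flat_groups_offset[g] <= offset < flat_groups_offset[g + 1])
end_group_id = next(g for g in range(len(group_numels))
                    if flat_groups_offset[g] < end_idx <= flat_groups_offset[g + 1])
assert (start_group_id, end_group_id) == (1, 1)  # the slice lies wholly in group 1
```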
@@ -424,7 +461,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     offset = 0
     total_numel = 0
     total_params = 0
-    for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
+    flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
+    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel
         total_params += 1
@@ -435,10 +473,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
                 f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
             )

-        # XXX: memory usage doubles here
-        state_dict[name] = torch.cat(
-            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
-            0).narrow(0, 0, unpartitioned_numel).view(shape)
+        # memory efficient tensor
+        tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
+        state_dict[name] = tensor
         offset += partitioned_numel

     offset *= world_size
@@ -473,7 +510,29 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
     return state_dict


-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
+def to_torch_tensor(state_dict, return_empty_tensor=False):
+    """
+    Convert state_dict of GatheredTensor to torch tensor
+    """
+    converted_tensors = {}
+    for name, tensor in state_dict.items():
+        tensor_id = id(tensor)
+        if tensor_id in converted_tensors:
+            shared_tensor = state_dict[converted_tensors[tensor_id]]
+            state_dict[name] = shared_tensor
+        else:
+            converted_tensors[tensor_id] = name
+            if return_empty_tensor:
+                state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
+            else:
+                state_dict[name] = tensor.contiguous()
+    return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+                                             tag=None,
+                                             exclude_frozen_parameters=False,
+                                             lazy_mode=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
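`to_torch_tensor` converts each distinct tensor object exactly once, keyed by `id()`, so tied weights stay tied after conversion (this matters because `GatheredTensor.contiguous()` builds a fresh tensor on every call). A minimal standalone demonstration of the same aliasing logic using plain tensors:

```python
import torch

tied = torch.zeros(3, 3)
state_dict = {"layer1.weight": tied, "layer2.weight": tied}  # tied weights

converted_tensors = {}
for name, tensor in state_dict.items():
    if id(tensor) in converted_tensors:
        # reuse the already-converted tensor instead of converting again
        state_dict[name] = state_dict[converted_tensors[id(tensor)]]
    else:
        converted_tensors[id(tensor)] = name
        state_dict[name] = tensor.contiguous()

assert state_dict["layer1.weight"] is state_dict["layer2.weight"]
```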
@@ -483,14 +542,12 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
         - ``checkpoint_dir``: path to the desired checkpoint folder
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
         - ``exclude_frozen_parameters``: exclude frozen parameters
+        - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
+          Convert a pseudo tensor to a torch tensor by calling ``.contiguous()``

     Returns:
         - pytorch ``state_dict``

-    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
-    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
-    the checkpoint.
-
     A typical usage might be ::

         from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
@@ -506,6 +563,16 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f

     If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

+    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
+    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+    the checkpoint. Or you can load the state_dict in lazy mode ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
+        for name, lazy_tensor in state_dict.items():
+            tensor = lazy_tensor.contiguous() # to cpu
+            print(name, tensor)
+            # del tensor to release memory if it is no longer in use
     """
     if tag is None:
         latest_path = os.path.join(checkpoint_dir, 'latest')
@@ -520,7 +587,11 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

-    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+    if lazy_mode:
+        return state_dict
+    else:
+        return to_torch_tensor(state_dict)


 def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
@@ -541,6 +612,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
         - ``exclude_frozen_parameters``: exclude frozen parameters
     """
+
     # Dependency pre-check
     if safe_serialization:
         try:
@@ -556,13 +628,18 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
             raise

     # Convert zero checkpoint to state_dict
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+                                                          tag,
+                                                          exclude_frozen_parameters,
+                                                          lazy_mode=True)

     # Shard the model if it is too big.
     weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
     if max_shard_size is not None:
         filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
-        state_dict_split = split_torch_state_dict_into_shards(state_dict,
+        # a memory-efficient approach for sharding
+        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
+        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
                                                               filename_pattern=filename_pattern,
                                                               max_shard_size=max_shard_size)
     else:
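`split_torch_state_dict_into_shards` (from `huggingface_hub`) only inspects shapes and dtypes when planning shards, which is why `return_empty_tensor=True` suffices here: the planner never reads tensor data, so no fp32 weights need to be materialized. A small standalone illustration with toy tensors:

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# uninitialized tensors still carry shape and dtype, which is all the planner uses
planning_state_dict = {
    "layer0.weight": torch.empty(1024, 1024, dtype=torch.float32),  # 4 MB
    "layer1.weight": torch.empty(1024, 1024, dtype=torch.float32),  # 4 MB
}
split = split_torch_state_dict_into_shards(planning_state_dict, max_shard_size="5MB")
print(split.is_sharded)           # True: the two tensors don't fit in one 5 MB shard
print(split.filename_to_tensors)  # mapping of shard file name -> tensor names
```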
@@ -571,15 +648,22 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         state_dict_split = StateDictSplit(is_sharded=False,
                                           filename_to_tensors={weights_name: list(state_dict.keys())})

-    # Save the model
+    # Save the model by shard
     os.makedirs(output_dir, exist_ok=True)
     filename_to_tensors = state_dict_split.filename_to_tensors.items()
     for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
-        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+        shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
+        shard_state_dict = to_torch_tensor(shard_state_dict)
         output_path = os.path.join(output_dir, shard_file)
         if safe_serialization:
-            save_file(shard, output_path, metadata={"format": "pt"})
+            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
         else:
-            torch.save(shard, output_path)
+            torch.save(shard_state_dict, output_path)
+        # release the memory of current shard
+        for tensor_name in shard_state_dict:
+            del state_dict[tensor_name]
+        del shard_state_dict
+        gc.collect()

     # Save index if sharded
     if state_dict_split.is_sharded:
New test file:

@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import torch.nn as nn
+
+import deepspeed
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from unit.common import DistributedTest
+
+
+class ModelWithSharedWeights(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.layer0 = nn.Linear(100, 100)
+        self.layer1 = nn.Linear(200, 200)
+        self.layer2 = nn.Linear(300, 300)
+        # tie layer 1 and layer 2
+        self.layer1.weight = self.layer2.weight
+
+
+class TestCheckpointConvert(DistributedTest):
+    world_size = 2
+
+    def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir):
+        config = {
+            "train_micro_batch_size_per_gpu": 2,
+            "zero_allow_untested_optimizer": True,
+            "zero_optimization": {
+                "stage": 3
+            },
+        }
+        model = ModelWithSharedWeights()
+        optimizer = torch.optim.Adam(model.parameters())
+
+        deepspeed_engine, _, _, _ = deepspeed.initialize(
+            config=config,
+            model=model,
+            optimizer=optimizer,
+        )
+        ds_save_dir = tmpdir / "checkpoint_ds"
+        deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint")
+
+        model = ModelWithSharedWeights()
+
+        # save checkpoint
+        fp32_save_dir = tmpdir / "checkpoint_fp32"
+        convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir)
+
+        # load state_dict from fp32 checkpoint
+        state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin')
+
+        # check shared tensor
+        assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight'])
+
+        # load state_dict into model
+        model.load_state_dict(state_dict, strict=True)