A faster and more memory-efficient implementation of `zero_to_fp32` (#6658)

This PR provides a faster and more memory-efficient implementation of
`zero_to_fp32`.


The previous version doubled the memory usage, which caused CPU OOM for
very large models (e.g., llama 405B):

b647fb2470/deepspeed/utils/zero_to_fp32.py (L438-L441)


## How does it work?

1. **Lazy loading**: Load the checkpoint with `mmap=True`, so the weights
are memory-mapped rather than all storages being loaded into memory.
2. **Lazy merge**: `GatheredTensor` holds the mmapped weights and a
tensor offset. It is a memory-efficient pseudo tensor. Only when
`tensor.contiguous()` is called does it load the related weights into
memory and merge them into a single tensor.
3. **Release memory in time**: Save the checkpoint shard by shard, and
release the memory once a shard is saved.


Throughout the process, only one shard of tensors is kept in memory (see the sketch below).
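
Below is a minimal sketch of the intended lazy usage, based on the `lazy_mode` flag and `GatheredTensor.contiguous()` added in this diff; it is an illustration of the mechanism, not the only way to drive it.

```python
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# lazy_mode=True returns memory-efficient pseudo tensors backed by the
# mmapped checkpoint files; nothing heavy is loaded into CPU memory yet.
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoint_dir", lazy_mode=True)

for name, lazy_tensor in state_dict.items():
    tensor = lazy_tensor.contiguous()  # weights are read and merged into CPU memory here
    print(name, tensor.shape)
    del tensor                         # release memory once this tensor is no longer needed
```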

## How much benefit in speed and memory?

Experiments were conducted on a Linux host with 1TB of memory. Here is a
detailed comparison:

| | world size | peak memory (GB) | elapsed time (h:mm:ss) |
|---------------------------|------------|------------------|------------------------|
| llama3-8B (old -> new)    | 8          | 90 -> 41         | 0:02:17 -> 0:01:10     |
| llama2-13B (old -> new)   | 8          | 146 -> 54        | 0:02:30 -> 0:01:47     |
| llama2-70B (old -> new)   | 16         | 789 -> 159       | 0:20:47 -> 0:20:45     |
| qwen1.5-110B (old -> new) | 32         | OOM -> 217       | ? -> 0:34:21           |
| llama3-405B (old -> new)  | 192        | OOM -> 262       | ? -> 2:09:59           |



You can reproduce the results with the following script:
```sh
# 1. install requirements
apt-get install time
# 2. prepare zero-3 checkpoints
# 3. convert zero checkpoints to fp32 checkpoints
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
```
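
If you prefer the Python API over the CLI, a rough equivalent is sketched below (argument names are taken from the function's docstring in this diff; the `max_shard_size` value shown is an assumption, not necessarily the default):

```python
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    ".",                      # checkpoint_dir containing the ZeRO-3 checkpoint
    "output_dir/",            # output_dir for the consolidated fp32 shards
    max_shard_size="5GB",     # assumed shard size; adjust as needed
    safe_serialization=True,  # write .safetensors instead of pytorch_model.bin
)
```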

- **memory**: Theoretically, this PR reduces the memory cost from `2M`
to `(1/n)M`, where `M` is the memory cost of the full weights and `n` is
the number of shards (see the sketch after this list).
- **speed**: The speed gain mainly comes from avoiding extra tensor
copying. The benefit may be slight.
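
As a purely illustrative instance of that bound (assumed model size and shard count, not measured numbers):

```python
# Illustrative arithmetic for the 2M -> (1/n)M bound; all numbers are assumptions.
params = 8e9                  # e.g. an 8B-parameter model
M = params * 4 / 2**30        # full fp32 weights in GiB (~30 GiB)
n = 8                         # number of output shards
print(f"old peak ~ {2 * M:.0f} GiB, new peak ~ {M / n:.1f} GiB per shard")
```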




## Implementation history

- [v1](19712a1c75 (diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447)): an hf_hub-compatible approach.
It was discarded due to the controversial implementation of `data_ptr()`.
- [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simpler
approach based on `torch.empty`.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Xu Song 2024-11-19 04:14:35 +08:00 committed by GitHub
Parent f594dbe3df
Commit dd40269426
No key found matching this signature
GPG key ID: B5690EEEBB952194
2 changed files with 177 additions and 33 deletions


@@ -21,7 +21,9 @@ import glob
import math
import os
import re
import gc
import json
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
from dataclasses import dataclass
@@ -146,8 +148,8 @@ def parse_model_states(files):
def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
for f in tqdm(files, desc='Loading checkpoint shards'):
state_dict = torch.load(f, map_location=device, mmap=True)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
@@ -179,19 +181,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
else:
raise ValueError(f"unknown zero stage {zero_stage}")
if zero_stage <= 2:
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
]
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
return zero_stage, world_size, fp32_flat_groups
@@ -398,9 +388,56 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
class GatheredTensor:
"""
A pseudo tensor that collects partitioned weights.
It is more memory efficient when there are multiple groups.
"""
def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
self.flat_groups = flat_groups
self.flat_groups_offset = flat_groups_offset
self.offset = offset
self.partitioned_numel = partitioned_numel
self.shape = shape
self.dtype = self.flat_groups[0][0].dtype
def contiguous(self):
"""
Merge partitioned weights from flat_groups into a single tensor.
"""
end_idx = self.offset + self.partitioned_numel
world_size = len(self.flat_groups)
pad_flat_param_chunks = []
for rank_i in range(world_size):
# for each rank, we need to collect weights from related group/groups
flat_groups_at_rank_i = self.flat_groups[rank_i]
start_group_id = None
end_group_id = None
for group_id in range(len(self.flat_groups_offset)):
if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
start_group_id = group_id
if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
end_group_id = group_id
break
# collect weights from related group/groups
for group_id in range(start_group_id, end_group_id + 1):
flat_tensor = flat_groups_at_rank_i[group_id]
start_offset = self.offset - self.flat_groups_offset[group_id]
end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
# collect weights from all ranks
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
return param
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
@@ -424,7 +461,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
offset = 0
total_numel = 0
total_params = 0
for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
@@ -435,10 +473,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
0).narrow(0, 0, unpartitioned_numel).view(shape)
# memory efficient tensor
tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
state_dict[name] = tensor
offset += partitioned_numel
offset *= world_size
@@ -473,7 +510,29 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
def to_torch_tensor(state_dict, return_empty_tensor=False):
"""
Convert state_dict of GatheredTensor to torch tensor
"""
converted_tensors = {}
for name, tensor in state_dict.items():
tensor_id = id(tensor)
if tensor_id in converted_tensors:
shared_tensor = state_dict[converted_tensors[tensor_id]]
state_dict[name] = shared_tensor
else:
converted_tensors[tensor_id] = name
if return_empty_tensor:
state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
else:
state_dict[name] = tensor.contiguous()
return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
tag=None,
exclude_frozen_parameters=False,
lazy_mode=False):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -483,14 +542,12 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
- ``checkpoint_dir``: path to the desired checkpoint folder
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
- ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
Convert a pseudo tensor to a torch tensor by calling ``.contiguous()``.
Returns:
- pytorch ``state_dict``
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
the checkpoint.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
@@ -506,6 +563,16 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
the checkpoint. Or you can load state_dict in lazy mode ::
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
for name, lazy_tensor in state_dict.items():
tensor = lazy_tensor.contiguous() # to cpu
print(name, tensor)
# del tensor to release memory if it is no longer in use
"""
if tag is None:
latest_path = os.path.join(checkpoint_dir, 'latest')
@@ -520,7 +587,11 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
if not os.path.isdir(ds_checkpoint_dir):
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
if lazy_mode:
return state_dict
else:
return to_torch_tensor(state_dict)
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
@@ -541,6 +612,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
"""
# Dependency pre-check
if safe_serialization:
try:
@@ -556,13 +628,18 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
raise
# Convert zero checkpoint to state_dict
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
tag,
exclude_frozen_parameters,
lazy_mode=True)
# Shard the model if it is too big.
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
if max_shard_size is not None:
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(state_dict,
# a memory-efficient approach for sharding
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size)
else:
@@ -571,15 +648,22 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
state_dict_split = StateDictSplit(is_sharded=False,
filename_to_tensors={weights_name: list(state_dict.keys())})
# Save the model
# Save the model by shard
os.makedirs(output_dir, exist_ok=True)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
shard_state_dict = to_torch_tensor(shard_state_dict)
output_path = os.path.join(output_dir, shard_file)
if safe_serialization:
save_file(shard, output_path, metadata={"format": "pt"})
save_file(shard_state_dict, output_path, metadata={"format": "pt"})
else:
torch.save(shard, output_path)
torch.save(shard_state_dict, output_path)
# release the memory of current shard
for tensor_name in shard_state_dict:
del state_dict[tensor_name]
del shard_state_dict
gc.collect()
# Save index if sharded
if state_dict_split.is_sharded:


@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import torch.nn as nn
import deepspeed
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from unit.common import DistributedTest
class ModelWithSharedWeights(nn.Module):
def __init__(self):
super().__init__()
self.layer0 = nn.Linear(100, 100)
self.layer1 = nn.Linear(200, 200)
self.layer2 = nn.Linear(300, 300)
# tie layer 1 and layer 2
self.layer1.weight = self.layer2.weight
class TestCheckpointConvert(DistributedTest):
world_size = 2
def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir):
config = {
"train_micro_batch_size_per_gpu": 2,
"zero_allow_untested_optimizer": True,
"zero_optimization": {
"stage": 3
},
}
model = ModelWithSharedWeights()
optimizer = torch.optim.Adam(model.parameters())
deepspeed_engine, _, _, _ = deepspeed.initialize(
config=config,
model=model,
optimizer=optimizer,
)
ds_save_dir = tmpdir / "checkpoint_ds"
deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint")
model = ModelWithSharedWeights()
# save checkpoint
fp32_save_dir = tmpdir / "checkpoint_fp32"
convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir)
# load state_dict from fp32 checkpoint
state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin')
# check shared tensor
assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight'])
# load state_dict into model
model.load_state_dict(state_dict, strict=True)