add sharded loading for safetensors in AutoTP (#4854)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
This commit is contained in:
Wang, Yi 2024-01-06 04:27:52 +08:00 коммит произвёл GitHub
Родитель c84c28d23b
Коммит c8c57b8c24
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 7 добавлений и 1 удалений

Просмотреть файл

@ -566,7 +566,12 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
"""
sd = None
if checkpoint is not None:
sd = torch.load(checkpoint, map_location='cpu')
if checkpoint.endswith(".safetensors"):
from safetensors.torch import load_file
sd = load_file(checkpoint)
else:
sd = torch.load(checkpoint, map_location='cpu')
policy = {}
if orig_class is not None:
policy.update({orig_class: (replace_fn, _replace_policy)})

Просмотреть файл

@ -1,5 +1,6 @@
google
lm-eval==0.3.0
protobuf
safetensors
transformers>=4.32.1
transformers[sentencepiece]