Use msgpack for P2P communication in pipeline engine.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Masahiro Tanaka 2024-09-25 17:34:38 -07:00 коммит произвёл GitHub
Родитель 61de017176
Коммит 7622cd9e68
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
2 изменённых файлов: 5 добавлений и 4 удалений

Просмотреть файл

@ -3,7 +3,7 @@
# DeepSpeed Team
import pickle
import msgpack
import typing
import torch
@ -96,7 +96,7 @@ def wait():
def send_obj(msg: typing.Any, dest: int):
"""Send an arbitrary python object to ``dest``.
Note: ``msg`` must be pickleable.
Note: ``msg`` must be serializable by msgpack.
WARN: This incurs a CPU -> GPU transfer and should be used sparingly
for performance reasons.
@ -106,7 +106,7 @@ def send_obj(msg: typing.Any, dest: int):
dest (int): Destination rank.
"""
# serialize the message
msg = pickle.dumps(msg)
msg = msgpack.packb(msg)
# construct a tensor to send
msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(get_accelerator().device_name())
@ -133,7 +133,7 @@ def recv_obj(sender: int) -> typing.Any:
msg = torch.empty(length.item(), dtype=torch.uint8).to(get_accelerator().device_name())
dist.recv(msg, src=sender)
msg = pickle.loads(msg.cpu().numpy().tobytes())
msg = msgpack.unpackb(msg.cpu().numpy().tobytes())
def _to(x):
"""Recursively move to the current device."""

Просмотреть файл

@ -1,4 +1,5 @@
hjson
msgpack
ninja
numpy
packaging>=20.0