From 6e6563d3c8d7527713cc48d4a3adce51f22e83a2 Mon Sep 17 00:00:00 2001 From: Lzhang-hub <57925599+Lzhang-hub@users.noreply.github.com> Date: Thu, 24 Oct 2024 04:29:30 +0800 Subject: [PATCH] fix init_device_mesh for torch 2.4 (#6614) Start torch 2.4, in [`init_device_mesh()`](https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/device_mesh.py#L915) ,device type with a GPU index, such as "cuda:0", is not allowed. ![image](https://github.com/user-attachments/assets/1ddb61bf-8a15-4e0a-9115-a3681d7f19ff) --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Co-authored-by: Masahiro Tanaka --- deepspeed/comm/torch.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index ed2645d41..988b74232 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -390,9 +390,14 @@ class TorchBackend(Backend): if not required_torch_version(min_version=2.2): raise RuntimeError(f"Current torch version does not have device mesh" f"api (torch.__version__: {torch.__version__})") - return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(), - mesh_shape, - mesh_dim_names=mesh_dim_names) + if not required_torch_version(max_version=2.4): + return torch.distributed.device_mesh.init_device_mesh(get_accelerator().device_name(), + mesh_shape, + mesh_dim_names=mesh_dim_names) + else: + return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(), + mesh_shape, + mesh_dim_names=mesh_dim_names) # This will become a light-weight wrapper around torch.distributed functions