diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index ed2645d41..988b74232 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -390,9 +390,14 @@ class TorchBackend(Backend):
         if not required_torch_version(min_version=2.2):
             raise RuntimeError(f"Current torch version does not have device mesh"
                                f"api (torch.__version__: {torch.__version__})")
-        return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
-                                                              mesh_shape,
-                                                              mesh_dim_names=mesh_dim_names)
+        if not required_torch_version(max_version=2.4):
+            return torch.distributed.device_mesh.init_device_mesh(get_accelerator().device_name(),
+                                                                  mesh_shape,
+                                                                  mesh_dim_names=mesh_dim_names)
+        else:
+            return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
+                                                                  mesh_shape,
+                                                                  mesh_dim_names=mesh_dim_names)
 
 
 # This will become a light-weight wrapper around torch.distributed functions