This PR changes the `init_method` for tests to `FileStore` for
robustness.
This commit is contained in:
Masahiro Tanaka 2024-10-17 15:15:25 -07:00 коммит произвёл GitHub
Родитель a36db9cc1c
Коммит c9fc34a4be
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 31 добавлений и 19 удалений

Просмотреть файл

@@ -147,16 +147,13 @@ class DistributedExec(ABC):
def run(self):
...
def __call__(self, request=None):
def __call__(self, request):
self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
world_size = self.world_size
if self.requires_cuda_env and not get_accelerator().is_available():
pytest.skip("only supported in accelerator environments.")
if isinstance(world_size, int):
world_size = [world_size]
for procs in world_size:
self._launch_procs(procs)
self._launch_with_file_store(request, world_size)
def _get_fixture_kwargs(self, request, func):
if not request:
@@ -172,7 +169,7 @@ class DistributedExec(ABC):
pass # test methods can have kwargs that are not fixtures
return fixture_kwargs
def _launch_daemonic_procs(self, num_procs):
def _launch_daemonic_procs(self, num_procs, init_method):
# Create process pool or use cached one
master_port = None
@@ -198,7 +195,7 @@ class DistributedExec(ABC):
master_port = get_master_port()
# Run the test
args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)]
skip_msgs_async = pool.starmap_async(self._dist_run, args)
try:
@@ -218,7 +215,7 @@ class DistributedExec(ABC):
assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
pytest.skip(skip_msgs[0])
def _launch_non_daemonic_procs(self, num_procs):
def _launch_non_daemonic_procs(self, num_procs, init_method):
assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes"
master_port = get_master_port()
@@ -227,7 +224,7 @@ class DistributedExec(ABC):
prev_start_method = mp.get_start_method()
mp.set_start_method('spawn', force=True)
for local_rank in range(num_procs):
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg))
p.start()
processes.append(p)
mp.set_start_method(prev_start_method, force=True)
@@ -269,7 +266,7 @@ class DistributedExec(ABC):
# add a check here to assert all exit messages are equal
pytest.skip(skip_msg.get())
def _launch_procs(self, num_procs):
def _launch_procs(self, num_procs, init_method):
# Verify we have enough accelerator devices to run this test
if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
pytest.skip(
@@ -284,11 +281,11 @@ class DistributedExec(ABC):
mp.set_start_method('forkserver', force=True)
if self.non_daemonic_procs:
self._launch_non_daemonic_procs(num_procs)
self._launch_non_daemonic_procs(num_procs, init_method)
else:
self._launch_daemonic_procs(num_procs)
self._launch_daemonic_procs(num_procs, init_method)
def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):
def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""):
if not dist.is_initialized():
""" Initialize deepspeed.comm and execute the user function. """
if self.set_dist_env:
@@ -312,7 +309,10 @@ class DistributedExec(ABC):
get_accelerator().set_device(local_rank)
if self.init_distributed:
deepspeed.init_distributed(dist_backend=self.backend)
deepspeed.init_distributed(dist_backend=self.backend,
init_method=init_method,
rank=local_rank,
world_size=num_procs)
dist.barrier()
try:
@@ -328,6 +328,22 @@ class DistributedExec(ABC):
return skip_msg
def _launch_with_file_store(self, request, world_size):
    """Run ``_launch_procs`` once per world size with a ``file://`` init method.

    The backing file lives at ``<tmpdir>/dist_file_store``, where ``tmpdir``
    is obtained from the pytest ``request`` via ``getfixturevalue``.

    Args:
        request: object exposing ``getfixturevalue`` (the pytest request).
        world_size: a single int, or an iterable of ints; each entry
            triggers one ``_launch_procs`` call.
    """
    store_path = request.getfixturevalue("tmpdir").join("dist_file_store")
    # The store file must not exist before the first launch.
    assert not os.path.exists(store_path)
    init_method = f"file://{store_path}"

    proc_counts = [world_size] if isinstance(world_size, int) else world_size
    for num_procs in proc_counts:
        try:
            self._launch_procs(num_procs, init_method)
        finally:
            # Always delete the store file, even when the launch raised or
            # skipped, so the next launch does not find a leftover file.
            if os.path.exists(store_path):
                os.remove(store_path)
        # Short pause between launches; skipped if the launch raised.
        # NOTE(review): presumably lets worker processes wind down — confirm.
        time.sleep(0.5)
def _dist_destroy(self):
if (dist is not None) and dist.is_initialized():
dist.barrier()
@@ -473,11 +489,7 @@ class DistributedTest(DistributedExec):
else:
world_size = self._fixture_kwargs.get("world_size", self.world_size)
if isinstance(world_size, int):
world_size = [world_size]
for procs in world_size:
self._launch_procs(procs)
time.sleep(0.5)
self._launch_with_file_store(request, world_size)
def _get_current_test_func(self, request):
# DistributedTest subclasses may have multiple test methods