зеркало из https://github.com/microsoft/DeepSpeed.git
Use file store for tests (#6632)
This PR changes the `init_method` for tests to `FileStore` for robustness.
This commit is contained in:
Родитель
a36db9cc1c
Коммит
c9fc34a4be
|
@ -147,16 +147,13 @@ class DistributedExec(ABC):
|
|||
def run(self):
|
||||
...
|
||||
|
||||
def __call__(self, request=None):
|
||||
def __call__(self, request):
|
||||
self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
|
||||
world_size = self.world_size
|
||||
if self.requires_cuda_env and not get_accelerator().is_available():
|
||||
pytest.skip("only supported in accelerator environments.")
|
||||
|
||||
if isinstance(world_size, int):
|
||||
world_size = [world_size]
|
||||
for procs in world_size:
|
||||
self._launch_procs(procs)
|
||||
self._launch_with_file_store(request, world_size)
|
||||
|
||||
def _get_fixture_kwargs(self, request, func):
|
||||
if not request:
|
||||
|
@ -172,7 +169,7 @@ class DistributedExec(ABC):
|
|||
pass # test methods can have kwargs that are not fixtures
|
||||
return fixture_kwargs
|
||||
|
||||
def _launch_daemonic_procs(self, num_procs):
|
||||
def _launch_daemonic_procs(self, num_procs, init_method):
|
||||
# Create process pool or use cached one
|
||||
master_port = None
|
||||
|
||||
|
@ -198,7 +195,7 @@ class DistributedExec(ABC):
|
|||
master_port = get_master_port()
|
||||
|
||||
# Run the test
|
||||
args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
|
||||
args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)]
|
||||
skip_msgs_async = pool.starmap_async(self._dist_run, args)
|
||||
|
||||
try:
|
||||
|
@ -218,7 +215,7 @@ class DistributedExec(ABC):
|
|||
assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
|
||||
pytest.skip(skip_msgs[0])
|
||||
|
||||
def _launch_non_daemonic_procs(self, num_procs):
|
||||
def _launch_non_daemonic_procs(self, num_procs, init_method):
|
||||
assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes"
|
||||
|
||||
master_port = get_master_port()
|
||||
|
@ -227,7 +224,7 @@ class DistributedExec(ABC):
|
|||
prev_start_method = mp.get_start_method()
|
||||
mp.set_start_method('spawn', force=True)
|
||||
for local_rank in range(num_procs):
|
||||
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
|
||||
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
mp.set_start_method(prev_start_method, force=True)
|
||||
|
@ -269,7 +266,7 @@ class DistributedExec(ABC):
|
|||
# add a check here to assert all exit messages are equal
|
||||
pytest.skip(skip_msg.get())
|
||||
|
||||
def _launch_procs(self, num_procs):
|
||||
def _launch_procs(self, num_procs, init_method):
|
||||
# Verify we have enough accelerator devices to run this test
|
||||
if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
|
||||
pytest.skip(
|
||||
|
@ -284,11 +281,11 @@ class DistributedExec(ABC):
|
|||
mp.set_start_method('forkserver', force=True)
|
||||
|
||||
if self.non_daemonic_procs:
|
||||
self._launch_non_daemonic_procs(num_procs)
|
||||
self._launch_non_daemonic_procs(num_procs, init_method)
|
||||
else:
|
||||
self._launch_daemonic_procs(num_procs)
|
||||
self._launch_daemonic_procs(num_procs, init_method)
|
||||
|
||||
def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):
|
||||
def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""):
|
||||
if not dist.is_initialized():
|
||||
""" Initialize deepspeed.comm and execute the user function. """
|
||||
if self.set_dist_env:
|
||||
|
@ -312,7 +309,10 @@ class DistributedExec(ABC):
|
|||
get_accelerator().set_device(local_rank)
|
||||
|
||||
if self.init_distributed:
|
||||
deepspeed.init_distributed(dist_backend=self.backend)
|
||||
deepspeed.init_distributed(dist_backend=self.backend,
|
||||
init_method=init_method,
|
||||
rank=local_rank,
|
||||
world_size=num_procs)
|
||||
dist.barrier()
|
||||
|
||||
try:
|
||||
|
@ -328,6 +328,22 @@ class DistributedExec(ABC):
|
|||
|
||||
return skip_msg
|
||||
|
||||
def _launch_with_file_store(self, request, world_size):
|
||||
tmpdir = request.getfixturevalue("tmpdir")
|
||||
dist_file_store = tmpdir.join("dist_file_store")
|
||||
assert not os.path.exists(dist_file_store)
|
||||
init_method = f"file://{dist_file_store}"
|
||||
|
||||
if isinstance(world_size, int):
|
||||
world_size = [world_size]
|
||||
for procs in world_size:
|
||||
try:
|
||||
self._launch_procs(procs, init_method)
|
||||
finally:
|
||||
if os.path.exists(dist_file_store):
|
||||
os.remove(dist_file_store)
|
||||
time.sleep(0.5)
|
||||
|
||||
def _dist_destroy(self):
|
||||
if (dist is not None) and dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
@ -473,11 +489,7 @@ class DistributedTest(DistributedExec):
|
|||
else:
|
||||
world_size = self._fixture_kwargs.get("world_size", self.world_size)
|
||||
|
||||
if isinstance(world_size, int):
|
||||
world_size = [world_size]
|
||||
for procs in world_size:
|
||||
self._launch_procs(procs)
|
||||
time.sleep(0.5)
|
||||
self._launch_with_file_store(request, world_size)
|
||||
|
||||
def _get_current_test_func(self, request):
|
||||
# DistributedTest subclasses may have multiple test methods
|
||||
|
|
Загрузка…
Ссылка в новой задаче