superbenchmark/tests/benchmarks/utils.py

65 строки
2.1 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Utilities for benchmark tests."""
import os
import multiprocessing as multiprocessing
from multiprocessing import Process
from superbench.benchmarks import BenchmarkRegistry
from superbench.common.utils import network
def clean_simulated_ddp_distributed_env():
    """Function to clean up the simulated DDP distributed environment variables.

    Removes WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT from ``os.environ``.
    Variables that are not currently set are skipped, so the function can be called more
    than once, or after an incomplete setup, without raising ``KeyError``.
    """
    for name in ('WORLD_SIZE', 'RANK', 'LOCAL_RANK', 'MASTER_ADDR', 'MASTER_PORT'):
        # Supplying a default turns pop() into a no-op for variables that are already absent.
        os.environ.pop(name, None)
def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
    """Function to export the environment variables describing a simulated DDP process.

    Args:
        world_size: value stored (as a string) in WORLD_SIZE.
        local_rank: value stored (as a string) in both RANK and LOCAL_RANK.
        port: value stored (as a string) in MASTER_PORT; MASTER_ADDR is always 'localhost'.
    """
    env_settings = (
        ('WORLD_SIZE', world_size),
        ('RANK', local_rank),
        ('LOCAL_RANK', local_rank),
        ('MASTER_ADDR', 'localhost'),
        ('MASTER_PORT', port),
    )
    # Environment values must be strings, so every entry is converted on assignment.
    for env_name, env_value in env_settings:
        os.environ[env_name] = str(env_value)
def benchmark_in_one_process(context, world_size, local_rank, port, queue):
    """Function to setup env for DDP initialization and run the benchmark in each single process.

    Args:
        context: benchmark context handed to ``BenchmarkRegistry.launch_benchmark``.
        world_size: total number of simulated processes.
        local_rank: rank of this process; used for both RANK and LOCAL_RANK.
        port: port exported as MASTER_PORT.
        queue: queue on which the outcome is put. It receives the benchmark object, or
            None when the launch raised before producing one.
    """
    setup_simulated_ddp_distributed_env(world_size, local_rank, port)
    benchmark = None
    try:
        benchmark = BenchmarkRegistry.launch_benchmark(context)
        # NOTE(review): guards against launch_benchmark returning None - confirm whether it can.
        if benchmark is not None:
            # parser object must be removed because it can not be serialized.
            benchmark._parser = None
    finally:
        # Always report an outcome so the parent is not left waiting on the queue, and
        # always undo the environment changes, even when the launch raised.
        try:
            queue.put(benchmark)
        finally:
            clean_simulated_ddp_distributed_env()
def simulated_ddp_distributed_benchmark(context, world_size):
    """Function to run the benchmark on #world_size number of processes.

    Args:
        context: benchmark context forwarded unchanged to every worker process.
        world_size (int): number of worker processes to start.

    Return:
        results (list): list of benchmark results from #world_size number of processes, in the
            order they arrive on the queue. The list is shorter than world_size if a worker
            exits without reporting a result. None is returned when no free port is available.
    """
    from queue import Empty

    port = network.get_free_port()
    if not port:
        return None

    # force=True lets this function be called repeatedly in one interpreter; without it the
    # second call raises RuntimeError because the start method has already been set.
    multiprocessing.set_start_method('spawn', force=True)
    result_queue = multiprocessing.Queue()

    process_list = []
    for rank in range(world_size):
        process = Process(
            target=benchmark_in_one_process, args=(context, world_size, rank, port, result_queue)
        )
        process.start()
        process_list.append(process)

    # Drain the queue BEFORE joining: a process that has put items on a queue waits for them
    # to be flushed before it terminates, so joining first can deadlock.
    results = []
    workers_exited = False
    while len(results) < world_size:
        try:
            results.append(result_queue.get(timeout=1))
        except Empty:
            if workers_exited:
                # Workers were already gone on the previous poll and nothing else arrived.
                break
            # Allow one more poll after all workers exit to pick up late-flushed results.
            workers_exited = not any(process.is_alive() for process in process_list)

    for process in process_list:
        process.join()
    return results